From d672b841e03f80ebb1e3b0f6dd87eca646dc44c1 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 17 Apr 2025 03:20:46 +0000
Subject: [PATCH 001/123] add jit-friendly dropout w rate in call
---
algoperf/jax_utils.py | 121 ++++++++++++++++++++++++++++++++++++++++++
1 file changed, 121 insertions(+)
create mode 100644 algoperf/jax_utils.py
diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py
new file mode 100644
index 000000000..ddafd77c6
--- /dev/null
+++ b/algoperf/jax_utils.py
@@ -0,0 +1,121 @@
+from collections.abc import Sequence
+
+import jax
+import jax.numpy as jnp
+from jax import lax, random
+
+import flax.linen as nn
+from flax.linen.module import Module, compact, merge_param
+from flax.typing import PRNGKey
+
+
+# Custom Layers
+class Dropout(Module):
+ """Create a dropout layer.
+ Forked from https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout.
+ The reference dropout implementation is modified to support changes to the dropout rate during training by:
+ 1) adding a rate argument to the __call__ method
+ 2) removing the if-else conditions that check for edge cases, which would otherwise trigger a recompile of jitted code
+
+ .. note::
+ When using :meth:`Module.apply()`, make sure
+ to include an RNG seed named ``'dropout'``. Dropout isn't necessary for
+ variable initialization.
+
+ Example usage::
+
+ >>> import flax.linen as nn
+ >>> import jax, jax.numpy as jnp
+
+ >>> class MLP(nn.Module):
+ ... @nn.compact
+ ... def __call__(self, x, train):
+ ... x = nn.Dense(4)(x)
+ ... x = nn.Dropout(0.5, deterministic=not train)(x)
+ ... return x
+
+ >>> model = MLP()
+ >>> x = jnp.ones((1, 3))
+ >>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout
+ >>> model.apply(variables, x, train=False) # don't use dropout
+ Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32)
+ >>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout
+ Array([[-0.35751054, 3.2510893 , 0. , 0. ]], dtype=float32)
+
+ Attributes:
+ rate: the dropout probability. (_not_ the keep rate!)
+ broadcast_dims: dimensions that will share the same dropout mask
+ deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` and
+ masked, whereas if true, no mask is applied and the inputs are returned as
+ is.
+ rng_collection: the rng collection name to use when requesting an rng key.
+ """
+
+ rate: float | None = None
+ broadcast_dims: Sequence[int] = ()
+ deterministic: bool | None = None
+ rng_collection: str = "dropout"
+ legacy: bool = True
+
+ @compact
+ def __call__(
+ self,
+ inputs,
+ deterministic: bool | None = None,
+ rate: float | None = None,
+ rng: PRNGKey | None = None,
+ ):
+ """Applies a random dropout mask to the input.
+
+ Args:
+ inputs: the inputs that should be randomly masked.
+ deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` and
+ masked, whereas if true, no mask is applied and the inputs are returned
+ as is.
+ rate: the dropout probability. (_not_ the keep rate!)
+ rng: an optional PRNGKey used as the random key, if not specified, one
+ will be generated using ``make_rng`` with the ``rng_collection`` name.
+
+ Returns:
+ The masked inputs reweighted to preserve mean.
+ """
+ deterministic = merge_param("deterministic", self.deterministic, deterministic)
+
+ # Override self.rate if rate is passed to __call__
+ if not (self.rate is not None and rate is not None):
+ rate = merge_param("rate", self.rate, rate)
+
+ if self.legacy:
+ if rate == 0.0:
+ return inputs
+
+ # Prevent gradient NaNs in 1.0 edge-case.
+ if rate == 1.0:
+ return jnp.zeros_like(inputs)
+
+ if deterministic:
+ return inputs
+
+ keep_prob = 1.0 - rate
+ if rng is None:
+ rng = self.make_rng(self.rng_collection)
+ broadcast_shape = list(inputs.shape)
+ for dim in self.broadcast_dims:
+ broadcast_shape[dim] = 1
+ mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
+ mask = jnp.broadcast_to(mask, inputs.shape)
+
+ return lax.select(
+ mask, jnp.nan_to_num(inputs / keep_prob), jnp.zeros_like(inputs)
+ )
+
+
+# Utilities for debugging
+def print_jax_model_summary(model, fake_inputs):
+ """Prints a summary of the jax module."""
+ tabulate_fn = nn.tabulate(
+ model,
+ jax.random.PRNGKey(0),
+ console_kwargs={"force_terminal": False, "force_jupyter": False, "width": 240},
+ )
+ print(tabulate_fn(fake_inputs, train=False))
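A minimal usage sketch of the layer added above (assuming it is importable as algoperf.jax_utils.Dropout; this is not part of the patch itself): the rate is supplied per call, so a single constructed layer can be applied with different dropout rates without being rebuilt.

import jax
import jax.numpy as jnp
from algoperf.jax_utils import Dropout

x = jnp.ones((2, 4))
layer = Dropout(deterministic=False)  # no rate fixed at construction

# The layer has no parameters, so an empty variables dict is enough; the
# 'dropout' RNG collection provides the randomness for the mask.
y_low = layer.apply({}, x, rate=0.1, rngs={"dropout": jax.random.PRNGKey(0)})
y_high = layer.apply({}, x, rate=0.5, rngs={"dropout": jax.random.PRNGKey(0)})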
From aa25e208ef7f283d2ee3d4ebc304e45d6c24fa8f Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Mon, 5 May 2025 08:28:01 +0000
Subject: [PATCH 002/123] remove nan_to_num conversion
---
algoperf/jax_utils.py | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py
index ddafd77c6..3ca3f1bfc 100644
--- a/algoperf/jax_utils.py
+++ b/algoperf/jax_utils.py
@@ -104,10 +104,7 @@ def __call__(
broadcast_shape[dim] = 1
mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
mask = jnp.broadcast_to(mask, inputs.shape)
-
- return lax.select(
- mask, jnp.nan_to_num(inputs / keep_prob), jnp.zeros_like(inputs)
- )
+ return lax.select(mask, inputs, jnp.zeros_like(inputs))
# Utilities for debugging
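For reference, a standalone sketch (plain JAX, independent of the module) of the masking that the new return line performs: lax.select keeps each element where the Bernoulli mask, drawn with probability keep_prob, is True and zeroes it elsewhere.

import jax.numpy as jnp
from jax import lax, random

x = jnp.arange(6.0).reshape(2, 3)
keep_prob = 0.9
mask = random.bernoulli(random.PRNGKey(0), p=keep_prob, shape=x.shape)
dropped = lax.select(mask, x, jnp.zeros_like(x))  # elementwise keep-or-zero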
From 85a35785a523ff44a1fca93e41d3cb082d72d4c2 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Mon, 5 May 2025 09:35:39 +0000
Subject: [PATCH 003/123] update models with custom dropout layer
---
.../workloads/criteo1tb/criteo1tb_jax/models.py | 6 +++---
algoperf/workloads/fastmri/fastmri_jax/models.py | 5 +++--
.../imagenet_vit/imagenet_jax/models.py | 13 +++++++------
.../librispeech_jax/models.py | 11 ++++++-----
.../librispeech_jax/models.py | 5 +++--
algoperf/workloads/ogbg/ogbg_jax/models.py | 4 +++-
algoperf/workloads/wmt/wmt_jax/models.py | 16 +++++++++-------
7 files changed, 34 insertions(+), 26 deletions(-)
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
index 6d9a489ff..e89db0c86 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
@@ -1,11 +1,11 @@
"""A JAX implementation of DLRM-Small."""
-
from typing import Sequence
import flax.linen as nn
from jax import nn as jnn
import jax.numpy as jnp
+from algoperf.jax_utils import Dropout
class DLRMResNet(nn.Module):
"""Define a DLRMResNet model.
@@ -89,7 +89,7 @@ def scaled_init(key, shape, dtype=jnp.float_):
top_mlp_input)
x = nn.relu(x)
if self.dropout_rate and layer_idx == num_layers_top - 2:
- x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
+ x = Dropout(rate=self.dropout_rate, deterministic=not train)(x)
top_mlp_input += x
# In the DLRM model the last layer width is always 1. We can hardcode that
# below.
@@ -212,7 +212,7 @@ def scaled_init(key, shape, dtype=jnp.float_):
top_mlp_input = nn.LayerNorm()(top_mlp_input)
if (self.dropout_rate is not None and self.dropout_rate > 0.0 and
layer_idx == num_layers_top - 2):
- top_mlp_input = nn.Dropout(
+ top_mlp_input = Dropout(
rate=self.dropout_rate, deterministic=not train)(
top_mlp_input)
logits = top_mlp_input
diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py
index 44bff0e21..f29e0be22 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/models.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/models.py
@@ -19,6 +19,7 @@
import jax
import jax.numpy as jnp
+from algoperf.jax_utils import Dropout
def _instance_norm2d(x, axes, epsilon=1e-5):
# promote x to at least float32, this avoids half precision computation
@@ -172,7 +173,7 @@ def __call__(self, x, train=True):
x = activation_fn(x)
# Ref code uses dropout2d which applies the same mask for the entire channel
# Replicated by using broadcast dims to have the same filter on HW
- x = nn.Dropout(
+ x = Dropout(
self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
x)
x = nn.Conv(
@@ -186,7 +187,7 @@ def __call__(self, x, train=True):
else:
x = _instance_norm2d(x, (1, 2))
x = activation_fn(x)
- x = nn.Dropout(
+ x = Dropout(
self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
x)
return x
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
index 7ce3a0395..902658fbe 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
@@ -11,6 +11,7 @@
import jax.numpy as jnp
from algoperf import spec
+from algoperf.jax_utils import Dropout
def posemb_sincos_2d(h: int,
@@ -53,7 +54,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
y = nn.Dense(self.mlp_dim, **inits)(x)
x = x * y
- x = nn.Dropout(rate=self.dropout_rate)(x, train)
+ x = Dropout(rate=self.dropout_rate)(x, train)
x = nn.Dense(d, **inits)(x)
return x
@@ -76,7 +77,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
deterministic=train,
name='MultiHeadDotProductAttention_1')(
y)
- y = nn.Dropout(rate=self.dropout_rate)(y, train)
+ y = Dropout(rate=self.dropout_rate)(y, train)
x = x + y
y = nn.LayerNorm(name='LayerNorm_2')(x)
@@ -85,7 +86,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
use_glu=self.use_glu,
dropout_rate=self.dropout_rate,
name='MlpBlock_3')(y, train)
- y = nn.Dropout(rate=self.dropout_rate)(y, train)
+ y = Dropout(rate=self.dropout_rate)(y, train)
x = x + y
else:
y = x
@@ -95,7 +96,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
deterministic=train,
name='MultiHeadDotProductAttention_1')(
y)
- y = nn.Dropout(rate=self.dropout_rate)(y, train)
+ y = Dropout(rate=self.dropout_rate)(y, train)
x = x + y
x = nn.LayerNorm(name='LayerNorm_0')(x)
@@ -105,7 +106,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
use_glu=self.use_glu,
dropout_rate=self.dropout_rate,
name='MlpBlock_3')(y, train)
- y = nn.Dropout(rate=self.dropout_rate)(y, train)
+ y = Dropout(rate=self.dropout_rate)(y, train)
x = x + y
x = nn.LayerNorm(name='LayerNorm_2')(x)
@@ -205,7 +206,7 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor:
dropout_rate = self.dropout_rate
if dropout_rate is None:
dropout_rate = 0.0
- x = nn.Dropout(rate=dropout_rate)(x, not train)
+ x = Dropout(rate=dropout_rate)(x, not train)
x = Encoder(
depth=self.depth,
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
index 593d463c3..51e93acc9 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
@@ -26,6 +26,7 @@
librispeech_preprocessor as preprocessor
from algoperf.workloads.librispeech_conformer.librispeech_jax import \
spectrum_augmenter
+from algoperf.jax_utils import Dropout
@struct.dataclass
@@ -129,7 +130,7 @@ def __call__(self, inputs, input_paddings, train):
outputs = outputs + AddPositionalEmbedding(embedding_dim=self.encoder_dim)(
seq_length=outputs.shape[1])
- outputs = nn.Dropout(
+ outputs = Dropout(
rate=self.input_dropout_rate, deterministic=not train)(
outputs)
@@ -217,7 +218,7 @@ def __call__(self, inputs, padding_mask=None, train=False):
'config.activation_function_name values, recieved '
f'{config.activation_function_name}')
inputs = activation_fn(inputs)
- inputs = nn.Dropout(rate=config.feed_forward_dropout_rate)(
+ inputs = Dropout(rate=config.feed_forward_dropout_rate)(
inputs, deterministic=not train)
inputs = inputs * padding_mask
@@ -234,7 +235,7 @@ def __call__(self, inputs, padding_mask=None, train=False):
else:
feed_forward_residual_dropout_rate = (
config.feed_forward_residual_dropout_rate)
- inputs = nn.Dropout(rate=feed_forward_residual_dropout_rate)(
+ inputs = Dropout(rate=feed_forward_residual_dropout_rate)(
inputs, deterministic=not train)
return inputs
@@ -416,7 +417,7 @@ def __call__(self, inputs, paddings, train):
attention_residual_dropout_rate = 0.1
else:
attention_residual_dropout_rate = config.attention_residual_dropout_rate
- result = nn.Dropout(
+ result = Dropout(
rate=attention_residual_dropout_rate, deterministic=not train)(
result)
@@ -578,7 +579,7 @@ def __call__(self,
conv_residual_dropout_rate = 0.0
else:
conv_residual_dropout_rate = config.conv_residual_dropout_rate
- inputs = nn.Dropout(
+ inputs = Dropout(
rate=conv_residual_dropout_rate, deterministic=not train)(
inputs)
return inputs
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
index b116f44cd..f937a1692 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
@@ -20,6 +20,7 @@
librispeech_preprocessor as preprocessor
from algoperf.workloads.librispeech_conformer.librispeech_jax import \
spectrum_augmenter
+from algoperf.jax_utils import Dropout
Array = jnp.ndarray
StateType = Union[Array, Tuple[Array, ...]]
@@ -110,7 +111,7 @@ def __call__(self, inputs, output_paddings, train):
input_dropout_rate = 0.1
else:
input_dropout_rate = config.input_dropout_rate
- outputs = nn.Dropout(
+ outputs = Dropout(
rate=input_dropout_rate, deterministic=not train)(
outputs)
@@ -216,7 +217,7 @@ def __call__(self, inputs, input_paddings=None, train=False):
feed_forward_dropout_rate = 0.1
else:
feed_forward_dropout_rate = config.feed_forward_dropout_rate
- inputs = nn.Dropout(rate=feed_forward_dropout_rate)(
+ inputs = Dropout(rate=feed_forward_dropout_rate)(
inputs, deterministic=not train)
return inputs
diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py
index 0e66d2ab8..f5710a3ab 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/models.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/models.py
@@ -6,6 +6,8 @@
import jax.numpy as jnp
import jraph
+from algoperf.jax_utils import Dropout
+
def _make_embed(latent_dim, name):
@@ -50,7 +52,7 @@ def __call__(self, graph, train):
dropout_rate = 0.1
else:
dropout_rate = self.dropout_rate
- dropout = nn.Dropout(rate=dropout_rate, deterministic=not train)
+ dropout = Dropout(rate=dropout_rate, deterministic=not train)
graph = graph._replace(
globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]))
diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py
index 97fee032f..04f46e8ac 100644
--- a/algoperf/workloads/wmt/wmt_jax/models.py
+++ b/algoperf/workloads/wmt/wmt_jax/models.py
@@ -11,6 +11,8 @@
import jax.numpy as jnp
import numpy as np
+from algoperf.jax_utils import Dropout
+
@struct.dataclass
class TransformerConfig:
@@ -172,14 +174,14 @@ def __call__(self, inputs):
dropout_rate = 0.1
else:
dropout_rate = cfg.dropout_rate
- x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
+ x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
output = nn.Dense(
actual_out_dim,
dtype=cfg.dtype,
kernel_init=cfg.kernel_init,
bias_init=cfg.bias_init)(
x)
- output = nn.Dropout(rate=dropout_rate)(
+ output = Dropout(rate=dropout_rate)(
output, deterministic=cfg.deterministic)
return output
@@ -229,7 +231,7 @@ def __call__(self, inputs, encoder_mask=None):
dropout_rate = 0.1
else:
dropout_rate = cfg.dropout_rate
- x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
+ x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
x = x + inputs
if not pre_ln:
x = nn.LayerNorm(dtype=cfg.dtype)(x)
@@ -293,7 +295,7 @@ def __call__(self,
dropout_rate = 0.1
else:
dropout_rate = cfg.dropout_rate
- x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
+ x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
x = x + targets
if not pre_ln:
x = nn.LayerNorm(dtype=cfg.dtype)(x)
@@ -312,7 +314,7 @@ def __call__(self,
deterministic=cfg.deterministic)(
cfg.attention_temp * y, encoded, mask=encoder_decoder_mask)
- y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic)
+ y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic)
y = y + x
if not pre_ln:
y = nn.LayerNorm(dtype=cfg.dtype)(y)
@@ -366,7 +368,7 @@ def __call__(self, inputs, inputs_positions=None, encoder_mask=None):
dropout_rate = 0.1
else:
dropout_rate = cfg.dropout_rate
- x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
+ x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
x = x.astype(cfg.dtype)
@@ -436,7 +438,7 @@ def __call__(self,
dropout_rate = 0.1
else:
dropout_rate = cfg.dropout_rate
- y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic)
+ y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic)
y = y.astype(cfg.dtype)
From 9354079c1a97b51fe722069b29608953f58aa107 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Mon, 5 May 2025 10:36:48 +0000
Subject: [PATCH 004/123] add functional dropout for criteo, fastmri, and vit
---
.../criteo1tb/criteo1tb_jax/models.py | 21 +++++----
.../workloads/fastmri/fastmri_jax/models.py | 20 ++++-----
.../imagenet_vit/imagenet_jax/models.py | 45 ++++++++++---------
3 files changed, 48 insertions(+), 38 deletions(-)
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
index e89db0c86..c56748bb1 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
@@ -28,7 +28,10 @@ class DLRMResNet(nn.Module):
embedding_init_multiplier: float = None # Unused
@nn.compact
- def __call__(self, x, train):
+ def __call__(self, x, train, dropout_rate=None):
+ if not dropout_rate:
+ dropout_rate=self.dropout_rate
+
bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
@@ -88,8 +91,8 @@ def scaled_init(key, shape, dtype=jnp.float_):
stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))(
top_mlp_input)
x = nn.relu(x)
- if self.dropout_rate and layer_idx == num_layers_top - 2:
- x = Dropout(rate=self.dropout_rate, deterministic=not train)(x)
+ if dropout_rate and layer_idx == num_layers_top - 2:
+ x = Dropout(deterministic=not train)(x, rate=dropout_rate)
top_mlp_input += x
# In the DLRM model the last layer width is always 1. We can hardcode that
# below.
@@ -151,7 +154,10 @@ class DlrmSmall(nn.Module):
embedding_init_multiplier: float = None
@nn.compact
- def __call__(self, x, train):
+ def __call__(self, x, train, dropout_rate=None):
+ if not dropout_rate:
+ dropout_rate = self.dropout_rate
+
bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
@@ -210,10 +216,9 @@ def scaled_init(key, shape, dtype=jnp.float_):
top_mlp_input = nn.relu(top_mlp_input)
if self.use_layer_norm:
top_mlp_input = nn.LayerNorm()(top_mlp_input)
- if (self.dropout_rate is not None and self.dropout_rate > 0.0 and
+ if (dropout_rate is not None and dropout_rate > 0.0 and
layer_idx == num_layers_top - 2):
- top_mlp_input = Dropout(
- rate=self.dropout_rate, deterministic=not train)(
- top_mlp_input)
+ top_mlp_input = Dropout(deterministic=not train)(
+ top_mlp_input, rate=dropout_rate)
logits = top_mlp_input
return logits
diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py
index f29e0be22..3d5460c18 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/models.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/models.py
@@ -62,10 +62,9 @@ class UNet(nn.Module):
use_layer_norm: bool = False
@nn.compact
- def __call__(self, x, train=True):
- dropout_rate = self.dropout_rate
- if dropout_rate is None:
- dropout_rate = 0.0
+ def __call__(self, x, train=True, dropout_rate=None):
+ if not dropout_rate:
+ dropout_rate = self.dropout_rate
# pylint: disable=invalid-name
_ConvBlock = functools.partial(
@@ -144,7 +143,7 @@ class ConvBlock(nn.Module):
use_layer_norm: bool
@nn.compact
- def __call__(self, x, train=True):
+ def __call__(self, x, train=True, dropout_rate=None):
"""Forward function.
Note: Pytorch is NCHW and jax/flax is NHWC.
Args:
@@ -153,6 +152,8 @@ def __call__(self, x, train=True):
Returns:
jnp.array: Output tensor of shape `(N, H, W, out_channels)`.
"""
+ if not dropout_rate:
+ dropout_rate=self.dropout_rate
x = nn.Conv(
features=self.out_channels,
kernel_size=(3, 3),
@@ -173,9 +174,8 @@ def __call__(self, x, train=True):
x = activation_fn(x)
# Ref code uses dropout2d which applies the same mask for the entire channel
# Replicated by using broadcast dims to have the same filter on HW
- x = Dropout(
- self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
- x)
+ x = Dropout(broadcast_dims=(1, 2), deterministic=not train)(
+ x, rate=dropout_rate )
x = nn.Conv(
features=self.out_channels,
kernel_size=(3, 3),
@@ -188,8 +188,8 @@ def __call__(self, x, train=True):
x = _instance_norm2d(x, (1, 2))
x = activation_fn(x)
x = Dropout(
- self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
- x)
+ broadcast_dims=(1, 2), deterministic=not train)(
+ x, rate=dropout_rate)
return x
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
index 902658fbe..10a90f37d 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
@@ -39,8 +39,11 @@ class MlpBlock(nn.Module):
dropout_rate: float = 0.0
@nn.compact
- def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
+ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spec.Tensor:
"""Applies Transformer MlpBlock module."""
+ if not dropout_rate:
+ dropout_rate = self.dropout_rate
+
inits = {
'kernel_init': nn.initializers.xavier_uniform(),
'bias_init': nn.initializers.normal(stddev=1e-6),
@@ -54,7 +57,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
y = nn.Dense(self.mlp_dim, **inits)(x)
x = x * y
- x = Dropout(rate=self.dropout_rate)(x, train)
+ x = Dropout()(x, train, rate=dropout_rate)
x = nn.Dense(d, **inits)(x)
return x
@@ -68,7 +71,10 @@ class Encoder1DBlock(nn.Module):
dropout_rate: float = 0.0
@nn.compact
- def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
+ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate) -> spec.Tensor:
+ if not dropout_rate:
+ dropout_rate=self.dropout_rate
+
if not self.use_post_layer_norm:
y = nn.LayerNorm(name='LayerNorm_0')(x)
y = nn.MultiHeadDotProductAttention(
@@ -77,16 +83,15 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
deterministic=train,
name='MultiHeadDotProductAttention_1')(
y)
- y = Dropout(rate=self.dropout_rate)(y, train)
+ y = Dropout()(y, train, dropout_rate=dropout_rate)
x = x + y
y = nn.LayerNorm(name='LayerNorm_2')(x)
y = MlpBlock(
mlp_dim=self.mlp_dim,
use_glu=self.use_glu,
- dropout_rate=self.dropout_rate,
- name='MlpBlock_3')(y, train)
- y = Dropout(rate=self.dropout_rate)(y, train)
+ name='MlpBlock_3')(y, train, dropout_rate=dropout_rate)
+ y = Dropout()(y, train, rate=dropout_rate)
x = x + y
else:
y = x
@@ -96,7 +101,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
deterministic=train,
name='MultiHeadDotProductAttention_1')(
y)
- y = Dropout(rate=self.dropout_rate)(y, train)
+ y = Dropout()(y, train, rate=dropout_rate)
x = x + y
x = nn.LayerNorm(name='LayerNorm_0')(x)
@@ -104,9 +109,8 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
y = MlpBlock(
mlp_dim=self.mlp_dim,
use_glu=self.use_glu,
- dropout_rate=self.dropout_rate,
- name='MlpBlock_3')(y, train)
- y = Dropout(rate=self.dropout_rate)(y, train)
+ name='MlpBlock_3')(y, train, dropout_rate=dropout_rate)
+ y = Dropout()(y, train)(rate=dropout_rate)
x = x + y
x = nn.LayerNorm(name='LayerNorm_2')(x)
@@ -123,7 +127,10 @@ class Encoder(nn.Module):
use_post_layer_norm: bool = False
@nn.compact
- def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
+ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spec.Tensor:
+ if not dropout_rate:
+ dropout_rate=self.dropout_rate
+
# Input Encoder
for lyr in range(self.depth):
block = Encoder1DBlock(
@@ -132,7 +139,7 @@ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
num_heads=self.num_heads,
use_glu=self.use_glu,
use_post_layer_norm=self.use_post_layer_norm,
- dropout_rate=self.dropout_rate)
+ )(dropout_rate=dropout_rate)
x = block(x, train)
if not self.use_post_layer_norm:
return nn.LayerNorm(name='encoder_layernorm')(x)
@@ -187,7 +194,9 @@ def get_posemb(self,
return posemb_sincos_2d(*seqshape, width, dtype=dtype)
@nn.compact
- def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor:
+ def __call__(self, x: spec.Tensor, *, train: bool = False, dropout_rate=None) -> spec.Tensor:
+ if not dropout_rate:
+ dropout_rate = self.dropout_rate
# Patch extraction
x = nn.Conv(
self.width,
@@ -203,10 +212,7 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor:
# Add posemb before adding extra token.
x = x + self.get_posemb((h, w), c, x.dtype)
- dropout_rate = self.dropout_rate
- if dropout_rate is None:
- dropout_rate = 0.0
- x = Dropout(rate=dropout_rate)(x, not train)
+ x = Dropout()(x, not train, rate=dropout_rate)
x = Encoder(
depth=self.depth,
@@ -214,9 +220,8 @@ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor:
num_heads=self.num_heads,
use_glu=self.use_glu,
use_post_layer_norm=self.use_post_layer_norm,
- dropout_rate=dropout_rate,
name='Transformer')(
- x, train=not train)
+ x, train=not train, dropout_rate=dropout_rate)
if self.use_map:
x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x)
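A side note on the `if not dropout_rate:` fallback introduced in this patch (patch 008 below switches the FastMRI and ViT versions to an explicit `is None` check): an explicit 0.0 is falsy and therefore falls back to the module default instead of disabling dropout. A small illustration with hypothetical helper names:

def resolve_rate_truthy(call_rate, module_rate):
    # Pattern in this patch: an explicit 0.0 falls back to the module default.
    return module_rate if not call_rate else call_rate

def resolve_rate_is_none(call_rate, module_rate):
    # Pattern adopted later: only None falls back; an explicit 0.0 is kept.
    return module_rate if call_rate is None else call_rate

assert resolve_rate_truthy(0.0, 0.5) == 0.5
assert resolve_rate_is_none(0.0, 0.5) == 0.0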
From feb9cc5cfd57575e11c91039c0292d26014e99fe Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Mon, 5 May 2025 10:48:34 +0000
Subject: [PATCH 005/123] add functional dropout for ogbg
---
algoperf/workloads/ogbg/ogbg_jax/models.py | 8 +++-----
1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py
index f5710a3ab..6ced9bef5 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/models.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/models.py
@@ -47,12 +47,10 @@ class GNN(nn.Module):
activation_fn_name: str = 'relu'
@nn.compact
- def __call__(self, graph, train):
- if self.dropout_rate is None:
- dropout_rate = 0.1
- else:
+ def __call__(self, graph, train, dropout_rate=None):
+ if not dropout_rate:
dropout_rate = self.dropout_rate
- dropout = Dropout(rate=dropout_rate, deterministic=not train)
+ dropout = Dropout(deterministic=not train, rate=dropout_rate)
graph = graph._replace(
globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]))
From 9bba078b5e17a7881ffaa294a551129d4acf5c65 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 15 May 2025 21:43:18 +0000
Subject: [PATCH 006/123] modify wmt model for dropout passing
---
algoperf/workloads/wmt/wmt_jax/models.py | 599 +++++++++++----------
algoperf/workloads/wmt/wmt_jax/workload.py | 4 +-
2 files changed, 310 insertions(+), 293 deletions(-)
diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py
index 04f46e8ac..a5b484320 100644
--- a/algoperf/workloads/wmt/wmt_jax/models.py
+++ b/algoperf/workloads/wmt/wmt_jax/models.py
@@ -140,325 +140,342 @@ def __call__(self, inputs, inputs_positions=None):
class MlpBlock(nn.Module):
- """Transformer MLP / feed-forward block.
+ """Transformer MLP / feed-forward block.
- Attributes:
- config: TransformerConfig dataclass containing hyperparameters.
- out_dim: optionally specify out dimension.
- """
- config: TransformerConfig
- out_dim: Optional[int] = None
+ Attributes:
+ config: TransformerConfig dataclass containing hyperparameters.
+ out_dim: optionally specify out dimension.
+ """
- @nn.compact
- def __call__(self, inputs):
- """Applies Transformer MlpBlock module."""
- cfg = self.config
- actual_out_dim = (
- inputs.shape[-1] if self.out_dim is None else self.out_dim)
- x = nn.Dense(
- cfg.mlp_dim,
- dtype=cfg.dtype,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init)(
- inputs)
- x = cfg.activation(x)
- if cfg.glu:
- y = nn.Dense(
- cfg.mlp_dim,
- dtype=cfg.dtype,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init)(
- inputs)
- x = x * y
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
- output = nn.Dense(
- actual_out_dim,
- dtype=cfg.dtype,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init)(
- x)
- output = Dropout(rate=dropout_rate)(
- output, deterministic=cfg.deterministic)
- return output
+ config: TransformerConfig
+ out_dim: Optional[int] = None
+
+ @nn.compact
+ def __call__(self, inputs, dropout_rate=None):
+ """Applies Transformer MlpBlock module."""
+ cfg = self.config
+ actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
+ x = nn.Dense(
+ cfg.mlp_dim,
+ dtype=cfg.dtype,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ )(inputs)
+ x = cfg.activation(x)
+ if cfg.glu:
+ y = nn.Dense(
+ cfg.mlp_dim,
+ dtype=cfg.dtype,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ )(inputs)
+ x = x * y
+ if dropout_rate is None:
+ if cfg.dropout_rate is None:
+ dropout_rate = 0.1
+ else:
+ dropout_rate = cfg.dropout_rate
+ x = Dropout()(x, rate=dropout_rate, deterministic=cfg.deterministic)
+ output = nn.Dense(
+ actual_out_dim,
+ dtype=cfg.dtype,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ )(x)
+ output = Dropout()(output, rate=dropout_rate, deterministic=cfg.deterministic)
+ return output
class Encoder1DBlock(nn.Module):
- """Transformer encoder layer.
-
- Attributes:
- config: TransformerConfig dataclass containing hyperparameters.
- """
- config: TransformerConfig
-
- @nn.compact
- def __call__(self, inputs, encoder_mask=None):
- """Applies Encoder1DBlock module.
+ """Transformer encoder layer.
- Args:
- inputs: input data.
- encoder_mask: encoder self-attention mask.
-
- Returns:
- output after transformer encoder block.
+ Attributes:
+ config: TransformerConfig dataclass containing hyperparameters.
"""
- cfg = self.config
- pre_ln = cfg.pre_ln
- # Attention block.
- assert inputs.ndim == 3
- x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs
- if cfg.attention_dropout_rate is None:
- attention_dropout_rate = 0.1
- else:
- attention_dropout_rate = cfg.attention_dropout_rate
- x = nn.MultiHeadDotProductAttention(
- num_heads=cfg.num_heads,
- dtype=cfg.dtype,
- qkv_features=cfg.qkv_dim,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- use_bias=False,
- broadcast_dropout=False,
- dropout_rate=attention_dropout_rate,
- deterministic=cfg.deterministic)(
- cfg.attention_temp * x, x, mask=encoder_mask)
-
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
- x = x + inputs
- if not pre_ln:
- x = nn.LayerNorm(dtype=cfg.dtype)(x)
-
- # MLP block.
- y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
- y = MlpBlock(config=cfg)(y)
-
- return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y)
+ config: TransformerConfig
+
+ @nn.compact
+ def __call__(self, inputs, encoder_mask=None, dropout_rate=None):
+ """Applies Encoder1DBlock module.
+
+ Args:
+ inputs: input data.
+ encoder_mask: encoder self-attention mask.
+
+ Returns:
+ output after transformer encoder block.
+ """
+ cfg = self.config
+ pre_ln = cfg.pre_ln
+
+ # Attention block.
+ assert inputs.ndim == 3
+ x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs
+ if dropout_rate is None:
+ if cfg.attention_dropout_rate is None:
+ dropout_rate = 0.1
+ else:
+ dropout_rate = cfg.dropout_rate
+ x = nn.MultiHeadDotProductAttention(
+ num_heads=cfg.num_heads,
+ dtype=cfg.dtype,
+ qkv_features=cfg.qkv_dim,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ use_bias=False,
+ broadcast_dropout=False,
+ dropout_rate=dropout_rate,
+ deterministic=cfg.deterministic,
+ )(cfg.attention_temp * x, x, mask=encoder_mask)
+
+ x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate)
+ x = x + inputs
+ if not pre_ln:
+ x = nn.LayerNorm(dtype=cfg.dtype)(x)
+
+ # MLP block.
+ y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
+ y = MlpBlock(config=cfg)(y)
+
+ return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y)
class EncoderDecoder1DBlock(nn.Module):
- """Transformer encoder-decoder layer.
-
- Attributes:
- config: TransformerConfig dataclass containing hyperparameters.
- """
- config: TransformerConfig
-
- @nn.compact
- def __call__(self,
- targets,
- encoded,
- decoder_mask=None,
- encoder_decoder_mask=None):
- """Applies EncoderDecoder1DBlock module.
+ """Transformer encoder-decoder layer.
- Args:
- targets: input data for decoder
- encoded: input data from encoder
- decoder_mask: decoder self-attention mask.
- encoder_decoder_mask: encoder-decoder attention mask.
-
- Returns:
- output after transformer encoder-decoder block.
+ Attributes:
+ config: TransformerConfig dataclass containing hyperparameters.
"""
- cfg = self.config
- pre_ln = cfg.pre_ln
- # Decoder block.
- assert targets.ndim == 3
- x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets
+ config: TransformerConfig
- if cfg.attention_dropout_rate is None:
- attention_dropout_rate = 0.1
- else:
- attention_dropout_rate = cfg.attention_dropout_rate
- x = nn.MultiHeadDotProductAttention(
- num_heads=cfg.num_heads,
- dtype=cfg.dtype,
- qkv_features=cfg.qkv_dim,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- use_bias=False,
- broadcast_dropout=False,
- dropout_rate=attention_dropout_rate,
- deterministic=cfg.deterministic,
- decode=cfg.decode)(
- cfg.attention_temp * x, x, mask=decoder_mask)
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
- x = x + targets
- if not pre_ln:
- x = nn.LayerNorm(dtype=cfg.dtype)(x)
-
- # Encoder-Decoder block.
- y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
- y = nn.MultiHeadDotProductAttention(
- num_heads=cfg.num_heads,
- dtype=cfg.dtype,
- qkv_features=cfg.qkv_dim,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- use_bias=False,
- broadcast_dropout=False,
- dropout_rate=attention_dropout_rate,
- deterministic=cfg.deterministic)(
- cfg.attention_temp * y, encoded, mask=encoder_decoder_mask)
-
- y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic)
- y = y + x
- if not pre_ln:
- y = nn.LayerNorm(dtype=cfg.dtype)(y)
-
- # MLP block.
- z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y
- z = MlpBlock(config=cfg)(z)
-
- return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z)
+ @nn.compact
+ def __call__(
+ self,
+ targets,
+ encoded,
+ decoder_mask=None,
+ encoder_decoder_mask=None,
+ dropout_rate=None,
+ ):
+ """Applies EncoderDecoder1DBlock module.
+
+ Args:
+ targets: input data for decoder
+ encoded: input data from encoder
+ decoder_mask: decoder self-attention mask.
+ encoder_decoder_mask: encoder-decoder attention mask.
+
+ Returns:
+ output after transformer encoder-decoder block.
+ """
+ cfg = self.config
+ pre_ln = cfg.pre_ln
+
+ # Decoder block.
+ assert targets.ndim == 3
+ x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets
+
+ if dropout_rate is None:
+ if cfg.attention_dropout_rate is None:
+ dropout_rate = 0.1
+ else:
+ dropout_rate = cfg.dropout_rate
+ x = nn.MultiHeadDotProductAttention(
+ num_heads=cfg.num_heads,
+ dtype=cfg.dtype,
+ qkv_features=cfg.qkv_dim,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ use_bias=False,
+ broadcast_dropout=False,
+ dropout_rate=dropout_rate,
+ deterministic=cfg.deterministic,
+ decode=cfg.decode,
+ )(cfg.attention_temp * x, x, mask=decoder_mask)
+ if cfg.dropout_rate is None:
+ dropout_rate = 0.1
+ else:
+ dropout_rate = cfg.dropout_rate
+ x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate)
+ x = x + targets
+ if not pre_ln:
+ x = nn.LayerNorm(dtype=cfg.dtype)(x)
+
+ # Encoder-Decoder block.
+ y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
+ y = nn.MultiHeadDotProductAttention(
+ num_heads=cfg.num_heads,
+ dtype=cfg.dtype,
+ qkv_features=cfg.qkv_dim,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ use_bias=False,
+ broadcast_dropout=False,
+ dropout_rate=dropout_rate,
+ deterministic=cfg.deterministic,
+ )(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask)
+
+ y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate)
+ y = y + x
+ if not pre_ln:
+ y = nn.LayerNorm(dtype=cfg.dtype)(y)
+
+ # MLP block.
+ z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y
+ z = MlpBlock(config=cfg)(z)
+
+ return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z)
class Encoder(nn.Module):
- """Transformer Model Encoder for sequence to sequence translation.
-
- Attributes:
- config: TransformerConfig dataclass containing hyperparameters.
- shared_embedding: a shared embedding layer to use.
- """
- config: TransformerConfig
- shared_embedding: Any = None
-
- @nn.compact
- def __call__(self, inputs, inputs_positions=None, encoder_mask=None):
- """Applies Transformer model on the inputs.
+ """Transformer Model Encoder for sequence to sequence translation.
- Args:
- inputs: input data
- inputs_positions: input subsequence positions for packed examples.
- encoder_mask: decoder self-attention mask.
-
- Returns:
- output of a transformer encoder.
+ Attributes:
+ config: TransformerConfig dataclass containing hyperparameters.
+ shared_embedding: a shared embedding layer to use.
"""
- cfg = self.config
- assert inputs.ndim == 2 # (batch, len)
- # Input Embedding
- if self.shared_embedding is None:
- input_embed = nn.Embed(
- num_embeddings=cfg.vocab_size,
- features=cfg.emb_dim,
- embedding_init=nn.initializers.normal(stddev=1.0))
- else:
- input_embed = self.shared_embedding
- x = inputs.astype('int32')
- x = input_embed(x)
- x = AddPositionEmbs(
- config=cfg, decode=False, name='posembed_input')(
- x, inputs_positions=inputs_positions)
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
-
- x = x.astype(cfg.dtype)
-
- # Input Encoder
- for lyr in range(cfg.num_layers):
- x = Encoder1DBlock(
- config=cfg, name=f'encoderblock_{lyr}')(x, encoder_mask)
-
- encoded = (
- nn.LayerNorm(dtype=cfg.dtype, name='encoder_layernorm')(x)
- if cfg.pre_ln else x)
-
- return encoded
+ config: TransformerConfig
+ shared_embedding: Any = None
+
+ @nn.compact
+ def __call__(
+ self, inputs, inputs_positions=None, encoder_mask=None, dropout_rate=None
+ ):
+ """Applies Transformer model on the inputs.
+
+ Args:
+ inputs: input data
+ inputs_positions: input subsequence positions for packed examples.
+ encoder_mask: decoder self-attention mask.
+
+ Returns:
+ output of a transformer encoder.
+ """
+ cfg = self.config
+ assert inputs.ndim == 2 # (batch, len)
+
+ # Input Embedding
+ if self.shared_embedding is None:
+ input_embed = nn.Embed(
+ num_embeddings=cfg.vocab_size,
+ features=cfg.emb_dim,
+ embedding_init=nn.initializers.normal(stddev=1.0),
+ )
+ else:
+ input_embed = self.shared_embedding
+ x = inputs.astype("int32")
+ x = input_embed(x)
+ x = AddPositionEmbs(config=cfg, decode=False, name="posembed_input")(
+ x, inputs_positions=inputs_positions
+ )
+ if dropout_rate is None:
+ if cfg.dropout_rate is None:
+ dropout_rate = 0.1
+ else:
+ dropout_rate = cfg.dropout_rate
+ x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate)
+
+ x = x.astype(cfg.dtype)
+
+ # Input Encoder
+ for lyr in range(cfg.num_layers):
+ x = Encoder1DBlock(config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask)
+
+ encoded = (
+ nn.LayerNorm(dtype=cfg.dtype, name="encoder_layernorm")(x)
+ if cfg.pre_ln
+ else x
+ )
+
+ return encoded
class Decoder(nn.Module):
- """Transformer Model Decoder for sequence to sequence translation.
+ """Transformer Model Decoder for sequence to sequence translation.
- Attributes:
- config: TransformerConfig dataclass containing hyperparameters.
- shared_embedding: a shared embedding layer to use.
- """
- config: TransformerConfig
- shared_embedding: Any = None
-
- @nn.compact
- def __call__(self,
- encoded,
- targets,
- targets_positions=None,
- decoder_mask=None,
- encoder_decoder_mask=None):
- """Applies Transformer model on the inputs.
-
- Args:
- encoded: encoded input data from encoder.
- targets: target inputs.
- targets_positions: input subsequence positions for packed examples.
- decoder_mask: decoder self-attention mask.
- encoder_decoder_mask: encoder-decoder attention mask.
-
- Returns:
- output of a transformer decoder.
+ Attributes:
+ config: TransformerConfig dataclass containing hyperparameters.
+ shared_embedding: a shared embedding layer to use.
"""
- cfg = self.config
- assert encoded.ndim == 3 # (batch, len, depth)
- assert targets.ndim == 2 # (batch, len)
+ config: TransformerConfig
+ shared_embedding: Any = None
- # Target Embedding
- if self.shared_embedding is None:
- output_embed = nn.Embed(
- num_embeddings=cfg.vocab_size,
- features=cfg.emb_dim,
- embedding_init=nn.initializers.normal(stddev=1.0))
- else:
- output_embed = self.shared_embedding
-
- y = targets.astype('int32')
- if not cfg.decode:
- y = shift_right(y)
- y = output_embed(y)
- y = AddPositionEmbs(
- config=cfg, decode=cfg.decode, name='posembed_output')(
- y, inputs_positions=targets_positions)
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic)
-
- y = y.astype(cfg.dtype)
-
- # Target-Input Decoder
- for lyr in range(cfg.num_layers):
- y = EncoderDecoder1DBlock(
- config=cfg, name=f'encoderdecoderblock_{lyr}')(
- y,
- encoded,
- decoder_mask=decoder_mask,
- encoder_decoder_mask=encoder_decoder_mask)
- y = (
- nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_layernorm')(y)
- if cfg.pre_ln else y)
-
- # Use the transpose of embedding matrix for logit transform.
- logits = output_embed.attend(y.astype(jnp.float32))
- # Correctly normalize pre-softmax logits for this shared case.
- logits = logits / jnp.sqrt(y.shape[-1])
- return logits
+ @nn.compact
+ def __call__(
+ self,
+ encoded,
+ targets,
+ targets_positions=None,
+ decoder_mask=None,
+ encoder_decoder_mask=None,
+ dropout_rate=None,
+ ):
+ """Applies Transformer model on the inputs.
+
+ Args:
+ encoded: encoded input data from encoder.
+ targets: target inputs.
+ targets_positions: input subsequence positions for packed examples.
+ decoder_mask: decoder self-attention mask.
+ encoder_decoder_mask: encoder-decoder attention mask.
+
+ Returns:
+ output of a transformer decoder.
+ """
+ cfg = self.config
+
+ assert encoded.ndim == 3 # (batch, len, depth)
+ assert targets.ndim == 2 # (batch, len)
+
+ # Target Embedding
+ if self.shared_embedding is None:
+ output_embed = nn.Embed(
+ num_embeddings=cfg.vocab_size,
+ features=cfg.emb_dim,
+ embedding_init=nn.initializers.normal(stddev=1.0),
+ )
+ else:
+ output_embed = self.shared_embedding
+
+ y = targets.astype("int32")
+ if not cfg.decode:
+ y = shift_right(y)
+ y = output_embed(y)
+ y = AddPositionEmbs(config=cfg, decode=cfg.decode, name="posembed_output")(
+ y, inputs_positions=targets_positions
+ )
+ if dropout_rate is None:
+ if cfg.dropout_rate is None:
+ dropout_rate = 0.1
+ else:
+ dropout_rate = cfg.dropout_rate
+ y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate)
+
+ y = y.astype(cfg.dtype)
+
+ # Target-Input Decoder
+ for lyr in range(cfg.num_layers):
+ y = EncoderDecoder1DBlock(config=cfg, name=f"encoderdecoderblock_{lyr}")(
+ y,
+ encoded,
+ decoder_mask=decoder_mask,
+ encoder_decoder_mask=encoder_decoder_mask,
+ )
+ y = (
+ nn.LayerNorm(dtype=cfg.dtype, name="encoderdecoder_layernorm")(y)
+ if cfg.pre_ln
+ else y
+ )
+
+ # Use the transpose of embedding matrix for logit transform.
+ logits = output_embed.attend(y.astype(jnp.float32))
+ # Correctly normalize pre-softmax logits for this shared case.
+ logits = logits / jnp.sqrt(y.shape[-1])
+ return logits
class Transformer(nn.Module):
diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py
index cdfcb91df..240ad2c11 100644
--- a/algoperf/workloads/wmt/wmt_jax/workload.py
+++ b/algoperf/workloads/wmt/wmt_jax/workload.py
@@ -209,8 +209,8 @@ def translate_and_calculate_bleu(self,
def init_model_fn(
self,
rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ dropout_rate: Optional[float] = 0.0,
+ aux_dropout_rate: Optional[float] = 0.0) -> spec.ModelInitState:
"""aux_dropout_rate is used as attention_dropout_rate."""
init_fake_batch_size = 2
From 31f601977335e170c54e50d37ece29ecaea9a314 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 15 May 2025 22:19:20 +0000
Subject: [PATCH 007/123] modify wmt model for dropout passing
---
.../librispeech_jax/models.py | 2 +-
algoperf/workloads/wmt/wmt_jax/models.py | 20 +++++++++++--------
2 files changed, 13 insertions(+), 9 deletions(-)
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
index f937a1692..003bf4ea7 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
@@ -73,7 +73,7 @@ class Subsample(nn.Module):
config: DeepspeechConfig
@nn.compact
- def __call__(self, inputs, output_paddings, train):
+ def __call__(self, inputs, output_paddings, train, dropout_rate=None):
config = self.config
outputs = jnp.expand_dims(inputs, axis=-1)
diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py
index a5b484320..b84fb6d96 100644
--- a/algoperf/workloads/wmt/wmt_jax/models.py
+++ b/algoperf/workloads/wmt/wmt_jax/models.py
@@ -236,7 +236,7 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=None):
# MLP block.
y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
- y = MlpBlock(config=cfg)(y)
+ y = MlpBlock(config=cfg)(y, dropout_rate=dropout_rate)
return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y)
@@ -324,7 +324,7 @@ def __call__(
# MLP block.
z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y
- z = MlpBlock(config=cfg)(z)
+ z = MlpBlock(config=cfg)(z, dropout_rate=dropout_rate)
return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z)
@@ -382,7 +382,7 @@ def __call__(
# Input Encoder
for lyr in range(cfg.num_layers):
- x = Encoder1DBlock(config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask)
+ x = Encoder1DBlock(config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask, dropout_rate)
encoded = (
nn.LayerNorm(dtype=cfg.dtype, name="encoder_layernorm")(x)
@@ -464,6 +464,7 @@ def __call__(
encoded,
decoder_mask=decoder_mask,
encoder_decoder_mask=encoder_decoder_mask,
+ dropout_rate=dropout_rate,
)
y = (
nn.LayerNorm(dtype=cfg.dtype, name="encoderdecoder_layernorm")(y)
@@ -503,7 +504,7 @@ def setup(self):
self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding)
self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding)
- def encode(self, inputs, inputs_positions=None, inputs_segmentation=None):
+ def encode(self, inputs, inputs_positions=None, inputs_segmentation=None, dropout_rate=None):
"""Applies Transformer encoder-branch on the inputs.
Args:
@@ -528,7 +529,7 @@ def encode(self, inputs, inputs_positions=None, inputs_segmentation=None):
jnp.equal,
dtype=cfg.dtype))
return self.encoder(
- inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask)
+ inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask, dropout_rate=dropout_rate)
def decode(
self,
@@ -595,7 +596,8 @@ def __call__(self,
inputs_positions=None,
targets_positions=None,
inputs_segmentation=None,
- targets_segmentation=None):
+ targets_segmentation=None,
+ dropout_rate=None):
"""Applies Transformer model on the inputs.
Args:
@@ -612,7 +614,8 @@ def __call__(self,
encoded = self.encode(
inputs,
inputs_positions=inputs_positions,
- inputs_segmentation=inputs_segmentation)
+ inputs_segmentation=inputs_segmentation,
+ dropout_rate=dropout_rate)
return self.decode(
encoded,
@@ -620,4 +623,5 @@ def __call__(self,
targets,
targets_positions=targets_positions,
inputs_segmentation=inputs_segmentation,
- targets_segmentation=targets_segmentation)
+ targets_segmentation=targets_segmentation,
+ dropout_rate=dropout_rate)
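With the plumbing above, the dropout rate can be chosen per call when applying the WMT Transformer; a hypothetical call site follows (the names model, params, inputs, targets, and dropout_rng are assumptions, not taken from the workload code).

logits = model.apply(
    {"params": params},
    inputs,
    targets,
    dropout_rate=0.1,                 # overrides cfg.dropout_rate for this call
    rngs={"dropout": dropout_rng})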
From e36d29432a960283f9a44baf745644c5c3ddbdaa Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 29 May 2025 23:32:41 +0000
Subject: [PATCH 008/123] reformatting and dropout fixes to fastmri and vit
---
.../criteo1tb/criteo1tb_jax/models.py | 7 +-
.../workloads/fastmri/fastmri_jax/models.py | 18 +-
.../workloads/fastmri/fastmri_jax/workload.py | 21 +-
.../imagenet_vit/imagenet_jax/models.py | 75 ++-
.../imagenet_vit/imagenet_jax/workload.py | 22 +-
.../librispeech_jax/models.py | 4 +-
algoperf/workloads/wmt/wmt_jax/models.py | 525 +++++++++---------
7 files changed, 362 insertions(+), 310 deletions(-)
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
index c56748bb1..0b2126915 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
@@ -7,6 +7,7 @@
from algoperf.jax_utils import Dropout
+
class DLRMResNet(nn.Module):
"""Define a DLRMResNet model.
@@ -30,7 +31,7 @@ class DLRMResNet(nn.Module):
@nn.compact
def __call__(self, x, train, dropout_rate=None):
if not dropout_rate:
- dropout_rate=self.dropout_rate
+ dropout_rate = self.dropout_rate
bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
@@ -157,7 +158,7 @@ class DlrmSmall(nn.Module):
def __call__(self, x, train, dropout_rate=None):
if not dropout_rate:
dropout_rate = self.dropout_rate
-
+
bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
@@ -219,6 +220,6 @@ def scaled_init(key, shape, dtype=jnp.float_):
if (dropout_rate is not None and dropout_rate > 0.0 and
layer_idx == num_layers_top - 2):
top_mlp_input = Dropout(deterministic=not train)(
- top_mlp_input, rate=dropout_rate)
+ top_mlp_input, rate=dropout_rate)
logits = top_mlp_input
return logits
diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py
index 3d5460c18..7ecca2add 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/models.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/models.py
@@ -21,6 +21,7 @@
from algoperf.jax_utils import Dropout
+
def _instance_norm2d(x, axes, epsilon=1e-5):
# promote x to at least float32, this avoids half precision computation
# but preserves double or complex floating points
@@ -57,13 +58,13 @@ class UNet(nn.Module):
num_channels: int = 32
num_pool_layers: int = 4
out_channels = 1
- dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0.
+ dropout_rate: float = 0.0
use_tanh: bool = False
use_layer_norm: bool = False
@nn.compact
def __call__(self, x, train=True, dropout_rate=None):
- if not dropout_rate:
+ if dropout_rate is None:
dropout_rate = self.dropout_rate
# pylint: disable=invalid-name
@@ -138,7 +139,7 @@ class ConvBlock(nn.Module):
dropout_rate: Dropout probability.
"""
out_channels: int
- dropout_rate: float
+ dropout_rate: float = 0.0
use_tanh: bool
use_layer_norm: bool
@@ -152,8 +153,8 @@ def __call__(self, x, train=True, dropout_rate=None):
Returns:
jnp.array: Output tensor of shape `(N, H, W, out_channels)`.
"""
- if not dropout_rate:
- dropout_rate=self.dropout_rate
+ if dropout_rate is None:
+ dropout_rate = self.dropout_rate
x = nn.Conv(
features=self.out_channels,
kernel_size=(3, 3),
@@ -174,8 +175,9 @@ def __call__(self, x, train=True, dropout_rate=None):
x = activation_fn(x)
# Ref code uses dropout2d which applies the same mask for the entire channel
# Replicated by using broadcast dims to have the same filter on HW
- x = Dropout(broadcast_dims=(1, 2), deterministic=not train)(
- x, rate=dropout_rate )
+ x = Dropout(
+ dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
+ x, rate=dropout_rate)
x = nn.Conv(
features=self.out_channels,
kernel_size=(3, 3),
@@ -188,7 +190,7 @@ def __call__(self, x, train=True, dropout_rate=None):
x = _instance_norm2d(x, (1, 2))
x = activation_fn(x)
x = Dropout(
- broadcast_dims=(1, 2), deterministic=not train)(
+ dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
x, rate=dropout_rate)
return x
diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py
index 1156cf30a..17ce6b442 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/workload.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py
@@ -26,12 +26,21 @@ def init_model_fn(
"""aux_dropout_rate is unused."""
del aux_dropout_rate
fake_batch = jnp.zeros((13, 320, 320))
- self._model = UNet(
- num_pool_layers=self.num_pool_layers,
- num_channels=self.num_channels,
- use_tanh=self.use_tanh,
- use_layer_norm=self.use_layer_norm,
- dropout_rate=dropout_rate)
+ if dropout_rate is None:
+ self._model = UNet(
+ num_pool_layers=self.num_pool_layers,
+ num_channels=self.num_channels,
+ use_tanh=self.use_tanh,
+ use_layer_norm=self.use_layer_norm,
+ )
+ else:
+ self._model = UNet(
+ num_pool_layers=self.num_pool_layers,
+ num_channels=self.num_channels,
+ use_tanh=self.use_tanh,
+ use_layer_norm=self.use_layer_norm,
+ dropout_rate=dropout_rate)
+
params_rng, dropout_rng = jax.random.split(rng)
variables = jax.jit(
self._model.init)({'params': params_rng, 'dropout': dropout_rng},
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
index 10a90f37d..227f7c297 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
@@ -39,9 +39,12 @@ class MlpBlock(nn.Module):
dropout_rate: float = 0.0
@nn.compact
- def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spec.Tensor:
+ def __call__(self,
+ x: spec.Tensor,
+ train: bool = True,
+ dropout_rate=None) -> spec.Tensor:
"""Applies Transformer MlpBlock module."""
- if not dropout_rate:
+ if dropout_rate is None:
dropout_rate = self.dropout_rate
inits = {
@@ -57,7 +60,7 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spe
y = nn.Dense(self.mlp_dim, **inits)(x)
x = x * y
- x = Dropout()(x, train, rate=dropout_rate)
+ x = Dropout(dropout_rate)(x, train, rate=dropout_rate)
x = nn.Dense(d, **inits)(x)
return x
@@ -71,9 +74,12 @@ class Encoder1DBlock(nn.Module):
dropout_rate: float = 0.0
@nn.compact
- def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate) -> spec.Tensor:
- if not dropout_rate:
- dropout_rate=self.dropout_rate
+ def __call__(self,
+ x: spec.Tensor,
+ train: bool = True,
+ dropout_rate=None) -> spec.Tensor:
+ if dropout_rate is None:
+ dropout_rate = self.dropout_rate
if not self.use_post_layer_norm:
y = nn.LayerNorm(name='LayerNorm_0')(x)
@@ -83,15 +89,14 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate
deterministic=train,
name='MultiHeadDotProductAttention_1')(
y)
- y = Dropout()(y, train, dropout_rate=dropout_rate)
+ y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
x = x + y
y = nn.LayerNorm(name='LayerNorm_2')(x)
y = MlpBlock(
- mlp_dim=self.mlp_dim,
- use_glu=self.use_glu,
- name='MlpBlock_3')(y, train, dropout_rate=dropout_rate)
- y = Dropout()(y, train, rate=dropout_rate)
+ mlp_dim=self.mlp_dim, use_glu=self.use_glu, name='MlpBlock_3')(
+ y, train, dropout_rate=dropout_rate)
+ y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
x = x + y
else:
y = x
@@ -101,7 +106,7 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate
deterministic=train,
name='MultiHeadDotProductAttention_1')(
y)
- y = Dropout()(y, train, rate=dropout_rate)
+ y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
x = x + y
x = nn.LayerNorm(name='LayerNorm_0')(x)
@@ -109,8 +114,10 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate
y = MlpBlock(
mlp_dim=self.mlp_dim,
use_glu=self.use_glu,
- name='MlpBlock_3')(y, train, dropout_rate=dropout_rate)
- y = Dropout()(y, train)(rate=dropout_rate)
+ name='MlpBlock_3',
+ dropout_rate=dropout_rate)(
+ y, train, dropout_rate=dropout_rate)
+ y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
x = x + y
x = nn.LayerNorm(name='LayerNorm_2')(x)
@@ -127,9 +134,12 @@ class Encoder(nn.Module):
use_post_layer_norm: bool = False
@nn.compact
- def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spec.Tensor:
- if not dropout_rate:
- dropout_rate=self.dropout_rate
+ def __call__(self,
+ x: spec.Tensor,
+ train: bool = True,
+ dropout_rate=None) -> spec.Tensor:
+ if dropout_rate is None:
+ dropout_rate = self.dropout_rate
# Input Encoder
for lyr in range(self.depth):
@@ -139,7 +149,8 @@ def __call__(self, x: spec.Tensor, train: bool = True, dropout_rate=None) -> spe
num_heads=self.num_heads,
use_glu=self.use_glu,
use_post_layer_norm=self.use_post_layer_norm,
- )(dropout_rate=dropout_rate)
+          dropout_rate=dropout_rate,
+      )
x = block(x, train)
if not self.use_post_layer_norm:
return nn.LayerNorm(name='encoder_layernorm')(x)
@@ -151,9 +162,12 @@ class MAPHead(nn.Module):
"""Multihead Attention Pooling."""
mlp_dim: Optional[int] = None # Defaults to 4x input dim
num_heads: int = 12
+  dropout_rate: float = 0.0
@nn.compact
- def __call__(self, x):
+ def __call__(self, x, dropout_rate=None):
+ if dropout_rate is None:
+ dropout_rate = self.dropout_rate
n, _, d = x.shape
probe = self.param('probe',
nn.initializers.xavier_uniform(), (1, 1, d),
@@ -166,7 +180,7 @@ def __call__(self, x):
kernel_init=nn.initializers.xavier_uniform())(probe, x)
y = nn.LayerNorm()(x)
- x = x + MlpBlock(mlp_dim=self.mlp_dim)(y)
+ x = x + MlpBlock(mlp_dim=self.mlp_dim, dropout_rate=dropout_rate)(y)
return x[:, 0]
@@ -180,7 +194,7 @@ class ViT(nn.Module):
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
num_heads: int = 12
rep_size: Union[int, bool] = True
- dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0.
+ dropout_rate: Optional[float] = 0.0
reinit: Optional[Sequence[str]] = None
head_zeroinit: bool = True
use_glu: bool = False
@@ -194,8 +208,12 @@ def get_posemb(self,
return posemb_sincos_2d(*seqshape, width, dtype=dtype)
@nn.compact
- def __call__(self, x: spec.Tensor, *, train: bool = False, dropout_rate=None) -> spec.Tensor:
- if not dropout_rate:
+ def __call__(self,
+ x: spec.Tensor,
+ *,
+ train: bool = False,
+ dropout_rate=None) -> spec.Tensor:
+ if dropout_rate is None:
dropout_rate = self.dropout_rate
# Patch extraction
x = nn.Conv(
@@ -212,7 +230,7 @@ def __call__(self, x: spec.Tensor, *, train: bool = False, dropout_rate=None) ->
# Add posemb before adding extra token.
x = x + self.get_posemb((h, w), c, x.dtype)
- x = Dropout()(x, not train, rate=dropout_rate)
+ x = Dropout(dropout_rate)(x, not train, rate=dropout_rate)
x = Encoder(
depth=self.depth,
@@ -220,11 +238,16 @@ def __call__(self, x: spec.Tensor, *, train: bool = False, dropout_rate=None) ->
num_heads=self.num_heads,
use_glu=self.use_glu,
use_post_layer_norm=self.use_post_layer_norm,
- name='Transformer')(
+ name='Transformer',
+ dropout_rate=dropout_rate)(
x, train=not train, dropout_rate=dropout_rate)
if self.use_map:
- x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x)
+ x = MAPHead(
+ num_heads=self.num_heads,
+ mlp_dim=self.mlp_dim,
+ dropout_rate=dropout_rate)(
+ x, dropout_rate=dropout_rate)
else:
x = jnp.mean(x, axis=1)
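The ViT hunks above all apply the same shape of change: each module keeps a dropout_rate attribute as its default and additionally accepts dropout_rate at call time, falling back to the attribute only when the caller passes None. Below is a minimal sketch of that fallback pattern, written against stock flax.linen.Dropout rather than the forked Dropout from algoperf/jax_utils.py; the module name, sizes, and the demo block are illustrative, not repo code.

import jax
import jax.numpy as jnp
import flax.linen as nn


class MyBlock(nn.Module):
  features: int = 16
  dropout_rate: float = 0.0  # module-level default

  @nn.compact
  def __call__(self, x, train: bool = True, dropout_rate=None):
    # Fall back to the attribute only when no rate is passed; `is None`
    # (rather than `not dropout_rate`) keeps an explicit 0.0 from being
    # overridden by the default.
    if dropout_rate is None:
      dropout_rate = self.dropout_rate
    x = nn.Dense(self.features)(x)
    x = nn.Dropout(rate=dropout_rate, deterministic=not train)(x)
    return x


if __name__ == "__main__":
  model = MyBlock(dropout_rate=0.1)
  x = jnp.ones((2, 8))
  variables = model.init(jax.random.PRNGKey(0), x, train=False)
  # Override the rate at apply time without rebuilding the module.
  y = model.apply(variables, x, train=True, dropout_rate=0.5,
                  rngs={"dropout": jax.random.PRNGKey(1)})
  print(y.shape)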
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
index 35a6c46be..5d07b5ff8 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -36,13 +36,21 @@ def init_model_fn(
dropout_rate: Optional[float] = None,
aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
del aux_dropout_rate
- self._model = models.ViT(
- dropout_rate=dropout_rate,
- num_classes=self._num_classes,
- use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm,
- use_map=self.use_map,
- **decode_variant('S/16'))
+ if dropout_rate is None:
+ self._model = models.ViT(
+ num_classes=self._num_classes,
+ use_glu=self.use_glu,
+ use_post_layer_norm=self.use_post_layer_norm,
+ use_map=self.use_map,
+ **decode_variant('S/16'))
+ else:
+ self._model = models.ViT(
+ dropout_rate=dropout_rate,
+ num_classes=self._num_classes,
+ use_glu=self.use_glu,
+ use_post_layer_norm=self.use_post_layer_norm,
+ use_map=self.use_map,
+ **decode_variant('S/16'))
params, model_state = self.initialized(rng, self._model)
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
index 003bf4ea7..4cdb02ee1 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
@@ -111,9 +111,7 @@ def __call__(self, inputs, output_paddings, train, dropout_rate=None):
input_dropout_rate = 0.1
else:
input_dropout_rate = config.input_dropout_rate
- outputs = Dropout(
- rate=input_dropout_rate, deterministic=not train)(
- outputs)
+ outputs = Dropout(rate=input_dropout_rate, deterministic=not train)(outputs)
return outputs, output_paddings
diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py
index b84fb6d96..54a917a09 100644
--- a/algoperf/workloads/wmt/wmt_jax/models.py
+++ b/algoperf/workloads/wmt/wmt_jax/models.py
@@ -140,64 +140,68 @@ def __call__(self, inputs, inputs_positions=None):
class MlpBlock(nn.Module):
- """Transformer MLP / feed-forward block.
+ """Transformer MLP / feed-forward block.
Attributes:
config: TransformerConfig dataclass containing hyperparameters.
out_dim: optionally specify out dimension.
"""
- config: TransformerConfig
- out_dim: Optional[int] = None
-
- @nn.compact
- def __call__(self, inputs, dropout_rate=None):
- """Applies Transformer MlpBlock module."""
- cfg = self.config
- actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
- x = nn.Dense(
- cfg.mlp_dim,
- dtype=cfg.dtype,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- )(inputs)
- x = cfg.activation(x)
- if cfg.glu:
- y = nn.Dense(
- cfg.mlp_dim,
- dtype=cfg.dtype,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- )(inputs)
- x = x * y
- if dropout_rate is None:
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = Dropout()(x, rate=dropout_rate, deterministic=cfg.deterministic)
- output = nn.Dense(
- actual_out_dim,
- dtype=cfg.dtype,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- )(x)
- output = Dropout()(output, rate=dropout_rate, deterministic=cfg.deterministic)
- return output
+ config: TransformerConfig
+ out_dim: Optional[int] = None
+
+ @nn.compact
+ def __call__(self, inputs, dropout_rate=None):
+ """Applies Transformer MlpBlock module."""
+ cfg = self.config
+ actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
+ x = nn.Dense(
+ cfg.mlp_dim,
+ dtype=cfg.dtype,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ )(
+ inputs)
+ x = cfg.activation(x)
+ if cfg.glu:
+ y = nn.Dense(
+ cfg.mlp_dim,
+ dtype=cfg.dtype,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ )(
+ inputs)
+ x = x * y
+ if dropout_rate is None:
+ if cfg.dropout_rate is None:
+ dropout_rate = 0.1
+ else:
+ dropout_rate = cfg.dropout_rate
+ x = Dropout()(x, rate=dropout_rate, deterministic=cfg.deterministic)
+ output = nn.Dense(
+ actual_out_dim,
+ dtype=cfg.dtype,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ )(
+ x)
+ output = Dropout()(
+ output, rate=dropout_rate, deterministic=cfg.deterministic)
+ return output
class Encoder1DBlock(nn.Module):
- """Transformer encoder layer.
+ """Transformer encoder layer.
Attributes:
config: TransformerConfig dataclass containing hyperparameters.
"""
- config: TransformerConfig
+ config: TransformerConfig
- @nn.compact
- def __call__(self, inputs, encoder_mask=None, dropout_rate=None):
- """Applies Encoder1DBlock module.
+ @nn.compact
+ def __call__(self, inputs, encoder_mask=None, dropout_rate=None):
+ """Applies Encoder1DBlock module.
Args:
inputs: input data.
@@ -206,60 +210,60 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=None):
Returns:
output after transformer encoder block.
"""
- cfg = self.config
- pre_ln = cfg.pre_ln
-
- # Attention block.
- assert inputs.ndim == 3
- x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs
- if dropout_rate is None:
- if cfg.attention_dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = nn.MultiHeadDotProductAttention(
- num_heads=cfg.num_heads,
- dtype=cfg.dtype,
- qkv_features=cfg.qkv_dim,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- use_bias=False,
- broadcast_dropout=False,
- dropout_rate=dropout_rate,
- deterministic=cfg.deterministic,
- )(cfg.attention_temp * x, x, mask=encoder_mask)
-
- x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate)
- x = x + inputs
- if not pre_ln:
- x = nn.LayerNorm(dtype=cfg.dtype)(x)
-
- # MLP block.
- y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
- y = MlpBlock(config=cfg)(y, dropout_rate=dropout_rate)
-
- return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y)
+ cfg = self.config
+ pre_ln = cfg.pre_ln
+
+ # Attention block.
+ assert inputs.ndim == 3
+ x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs
+ if dropout_rate is None:
+ if cfg.attention_dropout_rate is None:
+ dropout_rate = 0.1
+ else:
+ dropout_rate = cfg.dropout_rate
+ x = nn.MultiHeadDotProductAttention(
+ num_heads=cfg.num_heads,
+ dtype=cfg.dtype,
+ qkv_features=cfg.qkv_dim,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ use_bias=False,
+ broadcast_dropout=False,
+ dropout_rate=dropout_rate,
+ deterministic=cfg.deterministic,
+ )(cfg.attention_temp * x, x, mask=encoder_mask)
+
+ x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate)
+ x = x + inputs
+ if not pre_ln:
+ x = nn.LayerNorm(dtype=cfg.dtype)(x)
+
+ # MLP block.
+ y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
+ y = MlpBlock(config=cfg)(y, dropout_rate=dropout_rate)
+
+ return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y)
class EncoderDecoder1DBlock(nn.Module):
- """Transformer encoder-decoder layer.
+ """Transformer encoder-decoder layer.
Attributes:
config: TransformerConfig dataclass containing hyperparameters.
"""
- config: TransformerConfig
+ config: TransformerConfig
- @nn.compact
- def __call__(
- self,
- targets,
- encoded,
- decoder_mask=None,
- encoder_decoder_mask=None,
- dropout_rate=None,
- ):
- """Applies EncoderDecoder1DBlock module.
+ @nn.compact
+ def __call__(
+ self,
+ targets,
+ encoded,
+ decoder_mask=None,
+ encoder_decoder_mask=None,
+ dropout_rate=None,
+ ):
+ """Applies EncoderDecoder1DBlock module.
Args:
targets: input data for decoder
@@ -270,81 +274,83 @@ def __call__(
Returns:
output after transformer encoder-decoder block.
"""
- cfg = self.config
- pre_ln = cfg.pre_ln
-
- # Decoder block.
- assert targets.ndim == 3
- x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets
-
- if dropout_rate is None:
- if cfg.attention_dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = nn.MultiHeadDotProductAttention(
- num_heads=cfg.num_heads,
- dtype=cfg.dtype,
- qkv_features=cfg.qkv_dim,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- use_bias=False,
- broadcast_dropout=False,
- dropout_rate=dropout_rate,
- deterministic=cfg.deterministic,
- decode=cfg.decode,
- )(cfg.attention_temp * x, x, mask=decoder_mask)
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate)
- x = x + targets
- if not pre_ln:
- x = nn.LayerNorm(dtype=cfg.dtype)(x)
-
- # Encoder-Decoder block.
- y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
- y = nn.MultiHeadDotProductAttention(
- num_heads=cfg.num_heads,
- dtype=cfg.dtype,
- qkv_features=cfg.qkv_dim,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- use_bias=False,
- broadcast_dropout=False,
- dropout_rate=dropout_rate,
- deterministic=cfg.deterministic,
- )(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask)
-
- y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate)
- y = y + x
- if not pre_ln:
- y = nn.LayerNorm(dtype=cfg.dtype)(y)
-
- # MLP block.
- z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y
- z = MlpBlock(config=cfg)(z, dropout_rate=dropout_rate)
-
- return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z)
+ cfg = self.config
+ pre_ln = cfg.pre_ln
+
+ # Decoder block.
+ assert targets.ndim == 3
+ x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets
+
+ if dropout_rate is None:
+ if cfg.attention_dropout_rate is None:
+ dropout_rate = 0.1
+ else:
+ dropout_rate = cfg.dropout_rate
+ x = nn.MultiHeadDotProductAttention(
+ num_heads=cfg.num_heads,
+ dtype=cfg.dtype,
+ qkv_features=cfg.qkv_dim,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ use_bias=False,
+ broadcast_dropout=False,
+ dropout_rate=dropout_rate,
+ deterministic=cfg.deterministic,
+ decode=cfg.decode,
+ )(cfg.attention_temp * x, x, mask=decoder_mask)
+ if cfg.dropout_rate is None:
+ dropout_rate = 0.1
+ else:
+ dropout_rate = cfg.dropout_rate
+ x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate)
+ x = x + targets
+ if not pre_ln:
+ x = nn.LayerNorm(dtype=cfg.dtype)(x)
+
+ # Encoder-Decoder block.
+ y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
+ y = nn.MultiHeadDotProductAttention(
+ num_heads=cfg.num_heads,
+ dtype=cfg.dtype,
+ qkv_features=cfg.qkv_dim,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ use_bias=False,
+ broadcast_dropout=False,
+ dropout_rate=dropout_rate,
+ deterministic=cfg.deterministic,
+ )(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask)
+
+ y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate)
+ y = y + x
+ if not pre_ln:
+ y = nn.LayerNorm(dtype=cfg.dtype)(y)
+
+ # MLP block.
+ z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y
+ z = MlpBlock(config=cfg)(z, dropout_rate=dropout_rate)
+
+ return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z)
class Encoder(nn.Module):
- """Transformer Model Encoder for sequence to sequence translation.
+ """Transformer Model Encoder for sequence to sequence translation.
Attributes:
config: TransformerConfig dataclass containing hyperparameters.
shared_embedding: a shared embedding layer to use.
"""
- config: TransformerConfig
- shared_embedding: Any = None
+ config: TransformerConfig
+ shared_embedding: Any = None
- @nn.compact
- def __call__(
- self, inputs, inputs_positions=None, encoder_mask=None, dropout_rate=None
- ):
- """Applies Transformer model on the inputs.
+ @nn.compact
+ def __call__(self,
+ inputs,
+ inputs_positions=None,
+ encoder_mask=None,
+ dropout_rate=None):
+ """Applies Transformer model on the inputs.
Args:
inputs: input data
@@ -354,67 +360,66 @@ def __call__(
Returns:
output of a transformer encoder.
"""
- cfg = self.config
- assert inputs.ndim == 2 # (batch, len)
-
- # Input Embedding
- if self.shared_embedding is None:
- input_embed = nn.Embed(
- num_embeddings=cfg.vocab_size,
- features=cfg.emb_dim,
- embedding_init=nn.initializers.normal(stddev=1.0),
- )
- else:
- input_embed = self.shared_embedding
- x = inputs.astype("int32")
- x = input_embed(x)
- x = AddPositionEmbs(config=cfg, decode=False, name="posembed_input")(
- x, inputs_positions=inputs_positions
- )
- if dropout_rate is None:
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate)
-
- x = x.astype(cfg.dtype)
-
- # Input Encoder
- for lyr in range(cfg.num_layers):
- x = Encoder1DBlock(config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask, dropout_rate)
-
- encoded = (
- nn.LayerNorm(dtype=cfg.dtype, name="encoder_layernorm")(x)
- if cfg.pre_ln
- else x
- )
-
- return encoded
+ cfg = self.config
+ assert inputs.ndim == 2 # (batch, len)
+
+ # Input Embedding
+ if self.shared_embedding is None:
+ input_embed = nn.Embed(
+ num_embeddings=cfg.vocab_size,
+ features=cfg.emb_dim,
+ embedding_init=nn.initializers.normal(stddev=1.0),
+ )
+ else:
+ input_embed = self.shared_embedding
+ x = inputs.astype("int32")
+ x = input_embed(x)
+ x = AddPositionEmbs(
+ config=cfg, decode=False, name="posembed_input")(
+ x, inputs_positions=inputs_positions)
+ if dropout_rate is None:
+ if cfg.dropout_rate is None:
+ dropout_rate = 0.1
+ else:
+ dropout_rate = cfg.dropout_rate
+ x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate)
+
+ x = x.astype(cfg.dtype)
+
+ # Input Encoder
+ for lyr in range(cfg.num_layers):
+ x = Encoder1DBlock(
+ config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask, dropout_rate)
+
+ encoded = (
+ nn.LayerNorm(dtype=cfg.dtype, name="encoder_layernorm")(x)
+ if cfg.pre_ln else x)
+
+ return encoded
class Decoder(nn.Module):
- """Transformer Model Decoder for sequence to sequence translation.
+ """Transformer Model Decoder for sequence to sequence translation.
Attributes:
config: TransformerConfig dataclass containing hyperparameters.
shared_embedding: a shared embedding layer to use.
"""
- config: TransformerConfig
- shared_embedding: Any = None
+ config: TransformerConfig
+ shared_embedding: Any = None
- @nn.compact
- def __call__(
- self,
- encoded,
- targets,
- targets_positions=None,
- decoder_mask=None,
- encoder_decoder_mask=None,
- dropout_rate=None,
- ):
- """Applies Transformer model on the inputs.
+ @nn.compact
+ def __call__(
+ self,
+ encoded,
+ targets,
+ targets_positions=None,
+ decoder_mask=None,
+ encoder_decoder_mask=None,
+ dropout_rate=None,
+ ):
+ """Applies Transformer model on the inputs.
Args:
encoded: encoded input data from encoder.
@@ -426,57 +431,56 @@ def __call__(
Returns:
output of a transformer decoder.
"""
- cfg = self.config
-
- assert encoded.ndim == 3 # (batch, len, depth)
- assert targets.ndim == 2 # (batch, len)
-
- # Target Embedding
- if self.shared_embedding is None:
- output_embed = nn.Embed(
- num_embeddings=cfg.vocab_size,
- features=cfg.emb_dim,
- embedding_init=nn.initializers.normal(stddev=1.0),
- )
- else:
- output_embed = self.shared_embedding
-
- y = targets.astype("int32")
- if not cfg.decode:
- y = shift_right(y)
- y = output_embed(y)
- y = AddPositionEmbs(config=cfg, decode=cfg.decode, name="posembed_output")(
- y, inputs_positions=targets_positions
- )
- if dropout_rate is None:
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate)
-
- y = y.astype(cfg.dtype)
-
- # Target-Input Decoder
- for lyr in range(cfg.num_layers):
- y = EncoderDecoder1DBlock(config=cfg, name=f"encoderdecoderblock_{lyr}")(
- y,
- encoded,
- decoder_mask=decoder_mask,
- encoder_decoder_mask=encoder_decoder_mask,
- dropout_rate=dropout_rate,
- )
- y = (
- nn.LayerNorm(dtype=cfg.dtype, name="encoderdecoder_layernorm")(y)
- if cfg.pre_ln
- else y
- )
-
- # Use the transpose of embedding matrix for logit transform.
- logits = output_embed.attend(y.astype(jnp.float32))
- # Correctly normalize pre-softmax logits for this shared case.
- logits = logits / jnp.sqrt(y.shape[-1])
- return logits
+ cfg = self.config
+
+ assert encoded.ndim == 3 # (batch, len, depth)
+ assert targets.ndim == 2 # (batch, len)
+
+ # Target Embedding
+ if self.shared_embedding is None:
+ output_embed = nn.Embed(
+ num_embeddings=cfg.vocab_size,
+ features=cfg.emb_dim,
+ embedding_init=nn.initializers.normal(stddev=1.0),
+ )
+ else:
+ output_embed = self.shared_embedding
+
+ y = targets.astype("int32")
+ if not cfg.decode:
+ y = shift_right(y)
+ y = output_embed(y)
+ y = AddPositionEmbs(
+ config=cfg, decode=cfg.decode, name="posembed_output")(
+ y, inputs_positions=targets_positions)
+ if dropout_rate is None:
+ if cfg.dropout_rate is None:
+ dropout_rate = 0.1
+ else:
+ dropout_rate = cfg.dropout_rate
+ y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate)
+
+ y = y.astype(cfg.dtype)
+
+ # Target-Input Decoder
+ for lyr in range(cfg.num_layers):
+ y = EncoderDecoder1DBlock(
+ config=cfg, name=f"encoderdecoderblock_{lyr}")(
+ y,
+ encoded,
+ decoder_mask=decoder_mask,
+ encoder_decoder_mask=encoder_decoder_mask,
+ dropout_rate=dropout_rate,
+ )
+ y = (
+ nn.LayerNorm(dtype=cfg.dtype, name="encoderdecoder_layernorm")(y)
+ if cfg.pre_ln else y)
+
+ # Use the transpose of embedding matrix for logit transform.
+ logits = output_embed.attend(y.astype(jnp.float32))
+ # Correctly normalize pre-softmax logits for this shared case.
+ logits = logits / jnp.sqrt(y.shape[-1])
+ return logits
class Transformer(nn.Module):
@@ -504,7 +508,11 @@ def setup(self):
self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding)
self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding)
- def encode(self, inputs, inputs_positions=None, inputs_segmentation=None, dropout_rate=None):
+ def encode(self,
+ inputs,
+ inputs_positions=None,
+ inputs_segmentation=None,
+ dropout_rate=None):
"""Applies Transformer encoder-branch on the inputs.
Args:
@@ -529,7 +537,10 @@ def encode(self, inputs, inputs_positions=None, inputs_segmentation=None, dropou
jnp.equal,
dtype=cfg.dtype))
return self.encoder(
- inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask, dropout_rate=dropout_rate)
+ inputs,
+ inputs_positions=inputs_positions,
+ encoder_mask=encoder_mask,
+ dropout_rate=dropout_rate)
def decode(
self,
From 363da8ac032c82b2ed8ac1c7ea64a25be243cbc6 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 29 May 2025 23:39:39 +0000
Subject: [PATCH 009/123] dropout fix for criteo1tb jax
---
algoperf/workloads/criteo1tb/criteo1tb_jax/models.py | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
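The hunks below hand the rate to the forked Dropout both at construction and at call time, and only apply dropout on the penultimate top-MLP layer. A rough sketch of that layer loop, using stock nn.Dropout and made-up dimensions; TopMLP is a stand-in, not a repo class.

import jax
import jax.numpy as jnp
import flax.linen as nn


class TopMLP(nn.Module):
  mlp_top_dims: tuple = (64, 32, 1)
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, train: bool = True, dropout_rate=None):
    if dropout_rate is None:
      dropout_rate = self.dropout_rate
    num_layers_top = len(self.mlp_top_dims)
    for layer_idx, dim in enumerate(self.mlp_top_dims):
      x = nn.Dense(dim)(x)
      if layer_idx < num_layers_top - 1:
        x = nn.relu(x)
      # Dropout only on the penultimate layer, as in the DLRM-style top MLP.
      if dropout_rate > 0.0 and layer_idx == num_layers_top - 2:
        x = nn.Dropout(rate=dropout_rate, deterministic=not train)(x)
    return x


if __name__ == "__main__":
  x = jnp.ones((4, 13))
  m = TopMLP(dropout_rate=0.1)
  v = m.init(jax.random.PRNGKey(0), x, train=False)
  print(m.apply(v, x, train=True, dropout_rate=0.2,
                rngs={"dropout": jax.random.PRNGKey(1)}).shape)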
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
index 0b2126915..b7af15208 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
@@ -30,7 +30,7 @@ class DLRMResNet(nn.Module):
@nn.compact
def __call__(self, x, train, dropout_rate=None):
- if not dropout_rate:
+ if dropout_rate is None:
dropout_rate = self.dropout_rate
bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
@@ -93,7 +93,7 @@ def scaled_init(key, shape, dtype=jnp.float_):
top_mlp_input)
x = nn.relu(x)
if dropout_rate and layer_idx == num_layers_top - 2:
- x = Dropout(deterministic=not train)(x, rate=dropout_rate)
+        x = Dropout(dropout_rate, deterministic=not train)(x, rate=dropout_rate)
top_mlp_input += x
# In the DLRM model the last layer width is always 1. We can hardcode that
# below.
@@ -156,7 +156,7 @@ class DlrmSmall(nn.Module):
@nn.compact
def __call__(self, x, train, dropout_rate=None):
- if not dropout_rate:
+ if dropout_rate is None:
dropout_rate = self.dropout_rate
bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
@@ -219,7 +219,8 @@ def scaled_init(key, shape, dtype=jnp.float_):
top_mlp_input = nn.LayerNorm()(top_mlp_input)
if (dropout_rate is not None and dropout_rate > 0.0 and
layer_idx == num_layers_top - 2):
- top_mlp_input = Dropout(deterministic=not train)(
- top_mlp_input, rate=dropout_rate)
+ top_mlp_input = Dropout(
+ dropout_rate, deterministic=not train)(
+ top_mlp_input, rate=dropout_rate)
logits = top_mlp_input
return logits
From 341bf8996de4e4897eb9559124b1123c115dd62a Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 29 May 2025 23:42:00 +0000
Subject: [PATCH 010/123] dropout fix for criteo1tb jax
---
.../criteo1tb/criteo1tb_jax/workload.py | 29 +++++++++++++------
1 file changed, 20 insertions(+), 9 deletions(-)
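The workload change below only forwards dropout_rate to the model constructor when the submitter actually supplied one, so the module default stays authoritative otherwise. A sketch of that construction pattern with stand-in names (MyModel and make_model are illustrative, not repo symbols):

from typing import Optional

import flax.linen as nn


class MyModel(nn.Module):  # stand-in for models.DLRMResNet / models.DlrmSmall
  hidden: int = 32
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, train: bool = True):
    x = nn.Dense(self.hidden)(x)
    return nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)


def make_model(dropout_rate: Optional[float] = None) -> MyModel:
  # Only forward dropout_rate when it was actually supplied, so the module
  # default stays in charge otherwise.
  if dropout_rate is None:
    return MyModel(hidden=32)
  return MyModel(hidden=32, dropout_rate=dropout_rate)

An equivalent, less repetitive form would build a kwargs dict, e.g. extra = {} if dropout_rate is None else {'dropout_rate': dropout_rate}, and splat it into a single constructor call.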
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
index 91761e458..bad2f4390 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -82,15 +82,26 @@ def init_model_fn(
model_class = models.DLRMResNet
else:
model_class = models.DlrmSmall
- self._model = model_class(
- vocab_size=self.vocab_size,
- num_dense_features=self.num_dense_features,
- mlp_bottom_dims=self.mlp_bottom_dims,
- mlp_top_dims=self.mlp_top_dims,
- embed_dim=self.embed_dim,
- dropout_rate=dropout_rate,
- use_layer_norm=self.use_layer_norm,
- embedding_init_multiplier=self.embedding_init_multiplier)
+
+ if dropout_rate is None:
+ self._model = model_class(
+ vocab_size=self.vocab_size,
+ num_dense_features=self.num_dense_features,
+ mlp_bottom_dims=self.mlp_bottom_dims,
+ mlp_top_dims=self.mlp_top_dims,
+ embed_dim=self.embed_dim,
+ use_layer_norm=self.use_layer_norm,
+ embedding_init_multiplier=self.embedding_init_multiplier)
+ else:
+ self._model = model_class(
+ vocab_size=self.vocab_size,
+ num_dense_features=self.num_dense_features,
+ mlp_bottom_dims=self.mlp_bottom_dims,
+ mlp_top_dims=self.mlp_top_dims,
+ embed_dim=self.embed_dim,
+ dropout_rate=dropout_rate,
+ use_layer_norm=self.use_layer_norm,
+ embedding_init_multiplier=self.embedding_init_multiplier)
params_rng, dropout_rng = jax.random.split(rng)
init_fake_batch_size = 2
From f0c385bcec139050fe11bee01f0bc0ba0b9194d9 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 29 May 2025 23:49:42 +0000
Subject: [PATCH 011/123] remove aux dropout option from conformer and from
init_model_fn signature for fastmri, vit and criteo
---
.../criteo1tb/criteo1tb_jax/workload.py | 2 --
.../workloads/fastmri/fastmri_jax/workload.py | 3 +-
.../imagenet_vit/imagenet_jax/workload.py | 4 +--
.../librispeech_jax/workload.py | 28 ++++++++++++-------
4 files changed, 20 insertions(+), 17 deletions(-)
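With aux_dropout_rate gone, a single dropout_rate now drives every dropout site, and the config is only overridden when a rate was actually supplied. A sketch of that conditional-config pattern under those assumptions; SpeechConfig and its fields are illustrative stand-ins for the repo's ConformerConfig and DeepspeechConfig.

import dataclasses
from typing import Optional


@dataclasses.dataclass
class SpeechConfig:  # stand-in for ConformerConfig / DeepspeechConfig
  attention_residual_dropout_rate: float = 0.0
  feed_forward_residual_dropout_rate: float = 0.0
  input_dropout_rate: float = 0.0
  use_specaug: bool = True


def make_config(dropout_rate: Optional[float],
                use_specaug: bool = True) -> SpeechConfig:
  if dropout_rate is None:
    # No rate supplied: keep the dataclass defaults.
    return SpeechConfig(use_specaug=use_specaug)
  # A single rate fans out to every dropout site.
  return SpeechConfig(
      attention_residual_dropout_rate=dropout_rate,
      feed_forward_residual_dropout_rate=dropout_rate,
      input_dropout_rate=dropout_rate,
      use_specaug=use_specaug)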
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
index bad2f4390..e3864643b 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -73,11 +73,9 @@ def init_model_fn(
self,
rng: spec.RandomState,
dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None,
tabulate: Optional[bool] = False,
) -> spec.ModelInitState:
"""Only dropout is used."""
- del aux_dropout_rate
if self.use_resnet:
model_class = models.DLRMResNet
else:
diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py
index 17ce6b442..13ab5c1b8 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/workload.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py
@@ -22,9 +22,8 @@ def init_model_fn(
self,
rng: spec.RandomState,
dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ ) -> spec.ModelInitState:
"""aux_dropout_rate is unused."""
- del aux_dropout_rate
fake_batch = jnp.zeros((13, 320, 320))
if dropout_rate is None:
self._model = UNet(
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
index 5d07b5ff8..5107ed993 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -33,9 +33,7 @@ def initialized(self, key: spec.RandomState,
def init_model_fn(
self,
rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- del aux_dropout_rate
+ dropout_rate: Optional[float] = None) -> spec.ModelInitState:
if dropout_rate is None:
self._model = models.ViT(
num_classes=self._num_classes,
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
index 39012a20d..4da70fc61 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
@@ -61,24 +61,32 @@ def init_model_fn(
self,
rng: spec.RandomState,
dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ ) -> spec.ModelInitState:
"""Conformer model init function.
- Here we use dropout_rate as *_residual_dropout_rate, and aux_dropout_rate as
+    Here we use dropout_rate as *_residual_dropout_rate and as
input_dropout_rate.
"""
if self.use_gelu:
activation_function_name = 'gelu'
else:
activation_function_name = 'swish'
- model_config = models.ConformerConfig(
- attention_residual_dropout_rate=dropout_rate,
- feed_forward_residual_dropout_rate=dropout_rate,
- input_dropout_rate=aux_dropout_rate,
- use_specaug=self.use_specaug,
- attention_temperature=self.attention_temperature,
- use_post_layer_norm=self.use_post_layer_norm,
- activation_function_name=activation_function_name)
+    if dropout_rate is None:
+      model_config = models.ConformerConfig(
+          use_specaug=self.use_specaug,
+          attention_temperature=self.attention_temperature,
+          use_post_layer_norm=self.use_post_layer_norm,
+          activation_function_name=activation_function_name)
+    else:
+      model_config = models.ConformerConfig(
+          attention_residual_dropout_rate=dropout_rate,
+          feed_forward_residual_dropout_rate=dropout_rate,
+          input_dropout_rate=dropout_rate,
+          use_specaug=self.use_specaug,
+          attention_temperature=self.attention_temperature,
+          use_post_layer_norm=self.use_post_layer_norm,
+          activation_function_name=activation_function_name)
+
self._model = models.Conformer(model_config)
input_shape = [(320000,), (320000,)]
fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape]
From 7af5c941d81a7e66e3128afaf5b49c6f2730c302 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Sat, 31 May 2025 00:26:32 +0000
Subject: [PATCH 012/123] add dropout piping for conformer and deepspeech
---
.../workloads/fastmri/fastmri_jax/workload.py | 2 +-
.../librispeech_jax/models.py | 7 +--
.../librispeech_jax/workload.py | 2 +-
.../librispeech_jax/models.py | 34 +++++++++------
.../librispeech_jax/workload.py | 43 +++++++++++--------
algoperf/workloads/ogbg/ogbg_jax/models.py | 4 +-
algoperf/workloads/ogbg/ogbg_jax/workload.py | 26 ++++++-----
algoperf/workloads/wmt/wmt_jax/workload.py | 24 ++++++-----
8 files changed, 83 insertions(+), 59 deletions(-)
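The hunks below thread a call-time dropout_rate through nested modules: every submodule accepts dropout_rate=None, forwards it unchanged, and falls back to its config only at the leaves. A minimal sketch of that threading with illustrative names and sizes follows.

import jax
import jax.numpy as jnp
import flax.linen as nn


class Leaf(nn.Module):  # stand-in for Subsample / FeedForwardModule
  default_rate: float = 0.1

  @nn.compact
  def __call__(self, x, train: bool = True, dropout_rate=None):
    if dropout_rate is None:
      dropout_rate = self.default_rate
    x = nn.Dense(8)(x)
    return nn.Dropout(rate=dropout_rate, deterministic=not train)(x)


class Trunk(nn.Module):  # stand-in for the encoder that owns the submodules
  @nn.compact
  def __call__(self, x, train: bool = True, dropout_rate=None):
    # Forward the call-time rate unchanged; each leaf applies its own fallback.
    x = Leaf()(x, train, dropout_rate=dropout_rate)
    x = Leaf()(x, train, dropout_rate=dropout_rate)
    return x


if __name__ == "__main__":
  x = jnp.ones((2, 4))
  m = Trunk()
  v = m.init(jax.random.PRNGKey(0), x, train=False)
  y = m.apply(v, x, train=True, dropout_rate=0.3,
              rngs={"dropout": jax.random.PRNGKey(1)})
  print(y.shape)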
diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py
index 13ab5c1b8..439b8d055 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/workload.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py
@@ -22,7 +22,7 @@ def init_model_fn(
self,
rng: spec.RandomState,
dropout_rate: Optional[float] = None,
- ) -> spec.ModelInitState:
+ ) -> spec.ModelInitState:
"""aux_dropout_rate is unused."""
fake_batch = jnp.zeros((13, 320, 320))
if dropout_rate is None:
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
index 51e93acc9..29c349e11 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
@@ -38,13 +38,10 @@ class ConformerConfig:
num_attention_heads: int = 8
num_encoder_layers: int = 4
attention_dropout_rate: float = 0.0
- # If None, defaults to 0.1.
- attention_residual_dropout_rate: Optional[float] = 0.1
- # If None, defaults to 0.0.
+ attention_residual_dropout_rate: Optional[float] = 0.0
conv_residual_dropout_rate: Optional[float] = 0.0
feed_forward_dropout_rate: float = 0.0
- # If None, defaults to 0.1.
- feed_forward_residual_dropout_rate: Optional[float] = 0.1
+ feed_forward_residual_dropout_rate: Optional[float] = 0.0
convolution_kernel_size: int = 5
feed_forward_expansion_factor: int = 4
freq_mask_count: int = 2
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
index 4da70fc61..a54f52c04 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
@@ -61,7 +61,7 @@ def init_model_fn(
self,
rng: spec.RandomState,
dropout_rate: Optional[float] = None,
- ) -> spec.ModelInitState:
+ ) -> spec.ModelInitState:
"""Conformer model init function.
Here we use dropout_rate as *_residual_dropout_rate, and for
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
index 4cdb02ee1..3ad31b532 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
@@ -75,6 +75,9 @@ class Subsample(nn.Module):
@nn.compact
def __call__(self, inputs, output_paddings, train, dropout_rate=None):
config = self.config
+ if dropout_rate is None:
+ dropout_rate = config.dropout_rate
+
outputs = jnp.expand_dims(inputs, axis=-1)
outputs, output_paddings = Conv2dSubsampling(
@@ -111,7 +114,9 @@ def __call__(self, inputs, output_paddings, train, dropout_rate=None):
input_dropout_rate = 0.1
else:
input_dropout_rate = config.input_dropout_rate
- outputs = Dropout(rate=input_dropout_rate, deterministic=not train)(outputs)
+ outputs = Dropout(
+        rate=input_dropout_rate, deterministic=not train)(
+ outputs, rate=dropout_rate)
return outputs, output_paddings
@@ -187,7 +192,13 @@ class FeedForwardModule(nn.Module):
config: DeepspeechConfig
@nn.compact
- def __call__(self, inputs, input_paddings=None, train=False):
+ def __call__(self,
+ inputs,
+ input_paddings=None,
+ train=False,
+ dropout_rate=None):
+ if dropout_rate is None:
+ dropout_rate = self.config.feed_forward_dropout_rate
padding_mask = jnp.expand_dims(1 - input_paddings, -1)
config = self.config
@@ -211,12 +222,8 @@ def __call__(self, inputs, input_paddings=None, train=False):
inputs = nn.relu(inputs)
inputs *= padding_mask
- if config.feed_forward_dropout_rate is None:
- feed_forward_dropout_rate = 0.1
- else:
- feed_forward_dropout_rate = config.feed_forward_dropout_rate
- inputs = Dropout(rate=feed_forward_dropout_rate)(
- inputs, deterministic=not train)
+ inputs = Dropout(rate=dropout_rate)(
+ inputs, deterministic=not train, rate=dropout_rate)
return inputs
@@ -472,8 +479,10 @@ def setup(self):
)
@nn.compact
- def __call__(self, inputs, input_paddings, train):
+ def __call__(self, inputs, input_paddings, train, dropout_rate=None):
config = self.config
+ if dropout_rate is None:
+ dropout_rate = config.dropout_rate
outputs = inputs
output_paddings = input_paddings
@@ -493,7 +502,7 @@ def __call__(self, inputs, input_paddings, train):
# Subsample input by a factor of 4 by performing strided convolutions.
outputs, output_paddings = Subsample(
- config=config)(outputs, output_paddings, train)
+ config=config)(outputs, output_paddings, train, dropout_rate=dropout_rate)
# Run the lstm layers.
for _ in range(config.num_lstm_layers):
@@ -507,9 +516,8 @@ def __call__(self, inputs, input_paddings, train):
outputs = outputs + FeedForwardModule(config=self.config)(
outputs, output_paddings, train)
else:
- outputs = FeedForwardModule(config=self.config)(outputs,
- output_paddings,
- train)
+ outputs = FeedForwardModule(config=self.config)(
+ outputs, output_paddings, train, dropout_rate=dropout_rate)
# Run the decoder which in this case is a trivial projection layer.
if config.enable_decoder_layer_norm:
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
index d3b616f43..3c9a96f99 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
@@ -18,24 +18,31 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload):
def init_model_fn(
self,
rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ dropout_rate: Optional[float] = None) -> spec.ModelInitState:
"""Deepspeech model init function.
-
- Here we use dropout_rate as feed_forward_dropout_rate, and aux_dropout_rate
- as input_dropout_rate.
"""
- model_config = models.DeepspeechConfig(
- feed_forward_dropout_rate=dropout_rate,
- use_specaug=self.use_specaug,
- input_dropout_rate=aux_dropout_rate,
- use_tanh=self.use_tanh,
- enable_residual_connections=self.enable_residual_connections,
- enable_decoder_layer_norm=self.enable_decoder_layer_norm,
- layernorm_everywhere=self.layernorm_everywhere,
- freq_mask_count=self.freq_mask_count,
- time_mask_count=self.time_mask_count,
- )
+ if dropout_rate is None:
+ model_config = models.DeepspeechConfig(
+ use_specaug=self.use_specaug,
+ use_tanh=self.use_tanh,
+ enable_residual_connections=self.enable_residual_connections,
+ enable_decoder_layer_norm=self.enable_decoder_layer_norm,
+ layernorm_everywhere=self.layernorm_everywhere,
+ freq_mask_count=self.freq_mask_count,
+ time_mask_count=self.time_mask_count,
+ )
+ else:
+ model_config = models.DeepspeechConfig(
+ feed_forward_dropout_rate=dropout_rate,
+ use_specaug=self.use_specaug,
+ input_dropout_rate=dropout_rate,
+ use_tanh=self.use_tanh,
+ enable_residual_connections=self.enable_residual_connections,
+ enable_decoder_layer_norm=self.enable_decoder_layer_norm,
+ layernorm_everywhere=self.layernorm_everywhere,
+ freq_mask_count=self.freq_mask_count,
+ time_mask_count=self.time_mask_count,
+ )
self._model = models.Deepspeech(model_config)
input_shape = [(320000,), (320000,)]
fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape]
@@ -64,6 +71,7 @@ def model_fn(
rng: spec.RandomState,
update_batch_norm: bool,
use_running_average_bn: Optional[bool] = None
+ dropout_rate: Optional[bool] = None
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
variables = {'params': params, **model_state}
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
@@ -75,7 +83,8 @@ def model_fn(
input_paddings,
train=True,
rngs={'dropout' : rng},
- mutable=['batch_stats'])
+ mutable=['batch_stats'],
+ dropout_rate=dropout_rate)
return (logits, logit_paddings), new_model_state
else:
logits, logit_paddings = self._model.apply(
diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py
index 6ced9bef5..f6cb1c490 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/models.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/models.py
@@ -48,9 +48,9 @@ class GNN(nn.Module):
@nn.compact
def __call__(self, graph, train, dropout_rate=None):
- if not dropout_rate:
+    if dropout_rate is None:
dropout_rate = self.dropout_rate
- dropout = Dropout(deterministic=not train, rate=dropout_rate)
+    dropout = Dropout(dropout_rate, deterministic=not train)
graph = graph._replace(
globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]))
diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py
index e895d15a7..f7de3f982 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/workload.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py
@@ -20,18 +20,24 @@ class OgbgWorkload(BaseOgbgWorkload):
def init_model_fn(
self,
rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ dropout_rate: Optional[float] = None) -> spec.ModelInitState:
"""aux_dropout_rate is unused."""
- del aux_dropout_rate
rng, params_rng, dropout_rng = jax.random.split(rng, 3)
- self._model = models.GNN(
- self._num_outputs,
- dropout_rate=dropout_rate,
- activation_fn_name=self.activation_fn_name,
- hidden_dims=self.hidden_dims,
- latent_dim=self.latent_dim,
- num_message_passing_steps=self.num_message_passing_steps)
+ if dropout_rate is None:
+ self._model = models.GNN(
+ self._num_outputs,
+ activation_fn_name=self.activation_fn_name,
+ hidden_dims=self.hidden_dims,
+ latent_dim=self.latent_dim,
+ num_message_passing_steps=self.num_message_passing_steps)
+ else:
+ self._model = models.GNN(
+ self._num_outputs,
+ dropout_rate=dropout_rate,
+ activation_fn_name=self.activation_fn_name,
+ hidden_dims=self.hidden_dims,
+ latent_dim=self.latent_dim,
+ num_message_passing_steps=self.num_message_passing_steps)
init_fn = jax.jit(functools.partial(self._model.init, train=False))
fake_batch = jraph.GraphsTuple(
n_node=jnp.asarray([1]),
diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py
index 240ad2c11..b1f1e78a8 100644
--- a/algoperf/workloads/wmt/wmt_jax/workload.py
+++ b/algoperf/workloads/wmt/wmt_jax/workload.py
@@ -209,10 +209,7 @@ def translate_and_calculate_bleu(self,
def init_model_fn(
self,
rng: spec.RandomState,
- dropout_rate: Optional[float] = 0.0,
- aux_dropout_rate: Optional[float] = 0.0) -> spec.ModelInitState:
- """aux_dropout_rate is used as attention_dropout_rate."""
-
+ dropout_rate: Optional[float] = 0.0) -> spec.ModelInitState:
init_fake_batch_size = 2
input_shape = (init_fake_batch_size, 256)
target_shape = (init_fake_batch_size, 256)
@@ -224,13 +221,20 @@ def init_model_fn(
else:
raise ValueError(f'Unknown activation function {self.activation}.')
+ if dropout_rate is None:
+ model_config = models.TransformerConfig(
+ pre_ln=self.pre_ln,
+ attention_temp=self.attention_temp,
+ activation=activation,
+ glu=self.glu)
+ else:
model_config = models.TransformerConfig(
- dropout_rate=dropout_rate,
- attention_dropout_rate=aux_dropout_rate,
- pre_ln=self.pre_ln,
- attention_temp=self.attention_temp,
- activation=activation,
- glu=self.glu)
+ dropout_rate=dropout_rate,
+ attention_dropout_rate=dropout_rate,
+ pre_ln=self.pre_ln,
+ attention_temp=self.attention_temp,
+ activation=activation,
+ glu=self.glu)
self._train_model = models.Transformer(model_config)
eval_config = replace(model_config, deterministic=True)
self._eval_model = models.Transformer(eval_config)
From cbd065b490e1e57212d6b0112715ee73fcdf1a67 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Sat, 31 May 2025 01:30:22 +0000
Subject: [PATCH 013/123] pipe dropout through model_fn
---
.../criteo1tb/criteo1tb_jax/workload.py | 4 +-
.../workloads/fastmri/fastmri_jax/workload.py | 14 ++--
.../imagenet_vit/imagenet_jax/workload.py | 6 +-
.../librispeech_jax/models.py | 72 +++++++++----------
.../librispeech_jax/workload.py | 6 +-
.../librispeech_jax/workload.py | 2 +-
algoperf/workloads/ogbg/ogbg_jax/workload.py | 6 +-
7 files changed, 60 insertions(+), 50 deletions(-)
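The model_fn changes below forward dropout_rate to Module.apply only on the training pass, alongside the dropout rng. A sketch of that kwargs assembly; the surrounding workload class and spec types are elided, and the "train" string stands in for spec.ForwardPassMode.TRAIN.

def model_fn(self, params, batch, mode, rng, dropout_rate=None):
  train = mode == "train"
  apply_kwargs = {"train": train}
  if train:
    # Dropout (and its rng) only matter on the training pass.
    apply_kwargs["rngs"] = {"dropout": rng}
    apply_kwargs["dropout_rate"] = dropout_rate
  logits = self._model.apply({"params": params}, batch["inputs"],
                             **apply_kwargs)
  return logits, None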
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
index e3864643b..101e02c15 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -126,7 +126,8 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ update_batch_norm: bool,
+ dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
inputs = augmented_and_preprocessed_input_batch['inputs']
@@ -134,6 +135,7 @@ def model_fn(
apply_kwargs = {'train': train}
if train:
apply_kwargs['rngs'] = {'dropout': rng}
+ apply_kwargs['dropout_rate'] = dropout_rate
logits_batch = self._model.apply({'params': params}, inputs, **apply_kwargs)
return logits_batch, None
diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py
index 439b8d055..3d891cf8f 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/workload.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py
@@ -60,14 +60,18 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ update_batch_norm: bool,
+ dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
train = mode == spec.ForwardPassMode.TRAIN
- logits = self._model.apply({'params': params},
- augmented_and_preprocessed_input_batch['inputs'],
- rngs={'dropout': rng},
- train=train)
+
+ if train:
+ logits = self._model.apply({'params': params},
+ augmented_and_preprocessed_input_batch['inputs'],
+ rngs={'dropout': rng},
+ train=train,
+ dropout_rate=dropout_rate)
return logits, None
# Does NOT apply regularization, which is left to the submitter to do in
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
index 5107ed993..89355ac6e 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -66,14 +66,16 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ update_batch_norm: bool,
+ dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
train = mode == spec.ForwardPassMode.TRAIN
logits = self._model.apply({'params': params},
augmented_and_preprocessed_input_batch['inputs'],
rngs={'dropout': rng},
- train=train)
+ train=train,
+ dropout_rate=dropout_rate)
return logits, None
def _eval_model_on_split(self,
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
index 29c349e11..2ca0fffdc 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
@@ -37,7 +37,7 @@ class ConformerConfig:
encoder_dim: int = 512
num_attention_heads: int = 8
num_encoder_layers: int = 4
- attention_dropout_rate: float = 0.0
+ dropout_rate: float = 0.1
attention_residual_dropout_rate: Optional[float] = 0.0
conv_residual_dropout_rate: Optional[float] = 0.0
feed_forward_dropout_rate: float = 0.0
@@ -51,8 +51,6 @@ class ConformerConfig:
time_mask_max_ratio: float = 0.05
time_masks_per_frame: float = 0.0
use_dynamic_time_mask_max_frames: bool = True
- # If None, defaults to 0.1.
- input_dropout_rate: Optional[float] = 0.1
batch_norm_momentum: float = 0.999
batch_norm_epsilon: float = 0.001
use_specaug: bool = True
@@ -98,10 +96,12 @@ class Subsample(nn.Module):
input_dropout_rate: dropout rate for inputs.
"""
encoder_dim: int = 0
- input_dropout_rate: float = 0.0
+ dropout_rate: float = 0.0
@nn.compact
- def __call__(self, inputs, input_paddings, train):
+ def __call__(self, inputs, input_paddings, train, dropout_rate=None):
+ if dropout_rate is None:
+ dropout_rate = self.dropout_rate
output_paddings = input_paddings
outputs = jnp.expand_dims(inputs, axis=-1)
@@ -128,8 +128,8 @@ def __call__(self, inputs, input_paddings, train):
seq_length=outputs.shape[1])
outputs = Dropout(
- rate=self.input_dropout_rate, deterministic=not train)(
- outputs)
+ rate=dropout_rate, deterministic=not train)(
+ outputs, rate=dropout_rate)
return outputs, output_paddings
@@ -196,9 +196,10 @@ class FeedForwardModule(nn.Module):
config: ConformerConfig
@nn.compact
- def __call__(self, inputs, padding_mask=None, train=False):
+  def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=None):
config = self.config
-
+ if dropout_rate is None:
+ dropout_rate = config.dropout_rate
inputs = LayerNorm(dim=config.encoder_dim)(inputs)
inputs = nn.Dense(
@@ -387,8 +388,11 @@ class MultiHeadedSelfAttention(nn.Module):
config: ConformerConfig = None
@nn.compact
- def __call__(self, inputs, paddings, train):
+ def __call__(self, inputs, paddings, train, dropout_rate=None):
config = self.config
+ if dropout_rate is None:
+ dropout_rate = config.dropout_rate
+
mask_paddings = 1 - paddings
attention_mask = nn.make_attention_mask(
mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32)
@@ -410,13 +414,9 @@ def __call__(self, inputs, paddings, train):
deterministic=not train)(
inputs_q=inputs, mask=attention_mask)
- if config.attention_residual_dropout_rate is None:
- attention_residual_dropout_rate = 0.1
- else:
- attention_residual_dropout_rate = config.attention_residual_dropout_rate
result = Dropout(
- rate=attention_residual_dropout_rate, deterministic=not train)(
- result)
+ rate=dropout_rate, deterministic=not train)(
+ result, rate=dropout_rate)
return result
@@ -526,8 +526,11 @@ def __call__(self,
input_paddings,
train,
update_batch_norm,
- use_running_average_bn):
+ use_running_average_bn,
+ dropout_rate=None):
config = self.config
+ if dropout_rate is None:
+ dropout_rate = config.dropout_rate
inputs = LayerNorm(dim=config.encoder_dim)(inputs)
input_gated1 = nn.Dense(
@@ -572,13 +575,9 @@ def __call__(self,
config.encoder_dim, kernel_init=nn.initializers.xavier_uniform())(
inputs)
- if config.conv_residual_dropout_rate is None:
- conv_residual_dropout_rate = 0.0
- else:
- conv_residual_dropout_rate = config.conv_residual_dropout_rate
inputs = Dropout(
- rate=conv_residual_dropout_rate, deterministic=not train)(
- inputs)
+ rate=dropout_rate, deterministic=not train)(
+ inputs, rate=dropout_rate)
return inputs
@@ -603,26 +602,28 @@ def __call__(self,
input_paddings,
train,
update_batch_norm,
- use_running_average):
+ use_running_average,
+ dropout_rate=None):
config = self.config
padding_mask = jnp.expand_dims(1 - input_paddings, -1)
inputs = inputs + 0.5 * FeedForwardModule(config=self.config)(
- inputs, padding_mask, train)
+ inputs, padding_mask, train, dropout_rate)
inputs = inputs + MultiHeadedSelfAttention(config=self.config)(
- inputs, input_paddings, train)
+ inputs, input_paddings, train, dropout_rate=dropout_rate)
inputs = inputs + \
ConvolutionBlock(config)(inputs,
input_paddings,
train,
update_batch_norm,
- use_running_average
+ use_running_average,
+ dropout_rate
)
inputs = inputs + 0.5 * FeedForwardModule(config=self.config)(
- inputs, padding_mask, train)
+ inputs, padding_mask, train, dropout_rate)
if config.use_post_layer_norm:
inputs = LayerNorm(dim=config.encoder_dim)(inputs)
@@ -656,7 +657,8 @@ def __call__(self,
input_paddings,
train,
update_batch_norm: Optional[bool] = None,
- use_running_average_bn: Optional[bool] = None):
+ use_running_average_bn: Optional[bool] = None,
+ dropout_rate: Optional[float] = None):
config = self.config
outputs = inputs
@@ -681,15 +683,10 @@ def __call__(self,
if train and config.use_specaug:
outputs, output_paddings = self.specaug(outputs, output_paddings)
- # Subsample input by a factor of 4 by performing strided convolutions.
- if config.input_dropout_rate is None:
- input_dropout_rate = 0.1
- else:
- input_dropout_rate = config.input_dropout_rate
outputs, output_paddings = Subsample(
encoder_dim=config.encoder_dim,
- input_dropout_rate=input_dropout_rate)(
- outputs, output_paddings, train)
+ dropout_rate=dropout_rate)(
+ outputs, output_paddings, train, dropout_rate=dropout_rate)
# Run the conformer encoder layers.
for _ in range(config.num_encoder_layers):
@@ -697,7 +694,8 @@ def __call__(self,
output_paddings,
train,
update_batch_norm,
- use_running_average_bn)
+ use_running_average_bn,
+ dropout_rate)
outputs = LayerNorm(config.encoder_dim)(outputs)
# Run the decoder which in this case is a trivial projection layer.
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
index a54f52c04..2e082cf07 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
@@ -116,7 +116,8 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- use_running_average_bn: Optional[bool] = None
+ use_running_average_bn: Optional[bool] = None,
+ dropout_rate: Optional[float] = None,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
variables = {'params': params, **model_state}
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
@@ -129,7 +130,8 @@ def model_fn(
train=True,
rngs={'dropout' : rng},
mutable=['batch_stats'],
- use_running_average_bn=use_running_average_bn)
+ use_running_average_bn=use_running_average_bn,
+ dropout_rate=dropout_rate)
return (logits, logit_paddings), new_model_state
else:
logits, logit_paddings = self._model.apply(
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
index 3c9a96f99..825b470db 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
@@ -70,7 +70,7 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- use_running_average_bn: Optional[bool] = None
+ use_running_average_bn: Optional[bool] = None,
dropout_rate: Optional[bool] = None
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
variables = {'params': params, **model_state}
diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py
index f7de3f982..3becd5599 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/workload.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py
@@ -63,7 +63,8 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ update_batch_norm: bool,
+      dropout_rate: Optional[float] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
"""Get predicted logits from the network for input graphs."""
del update_batch_norm # No BN in the GNN model.
if model_state is not None:
@@ -74,7 +75,8 @@ def model_fn(
logits = self._model.apply({'params': params},
augmented_and_preprocessed_input_batch['inputs'],
rngs={'dropout': rng},
- train=train)
+ train=train,
+ dropout_rate=dropout_rate)
return logits, None
def _binary_cross_entropy_with_mask(
From 31babfd9da6e2ad55156f55aeeb1c9cf10d88edc Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 4 Jun 2025 19:05:10 +0000
Subject: [PATCH 014/123] fix syntax
---
algoperf/workloads/wmt/wmt_jax/workload.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py
index b1f1e78a8..367c062cb 100644
--- a/algoperf/workloads/wmt/wmt_jax/workload.py
+++ b/algoperf/workloads/wmt/wmt_jax/workload.py
@@ -228,7 +228,7 @@ def init_model_fn(
activation=activation,
glu=self.glu)
else:
- model_config = models.TransformerConfig(
+ model_config = models.TransformerConfig(
dropout_rate=dropout_rate,
attention_dropout_rate=dropout_rate,
pre_ln=self.pre_ln,
From 95d67db14e4c0d68e4868b55240c2211b8b039af Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 4 Jun 2025 20:45:08 +0000
Subject: [PATCH 015/123] dropout changes wmt jax
---
algoperf/workloads/wmt/wmt_jax/models.py | 67 +++++++++-------------
algoperf/workloads/wmt/wmt_jax/workload.py | 6 +-
2 files changed, 30 insertions(+), 43 deletions(-)
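The WMT hunks below resolve the rate once at the top of __call__ from the config default and then hand it to the forked Dropout both at construction and at call time. A sketch of that resolution step; Cfg and Block are illustrative stand-ins for TransformerConfig and the real encoder/decoder blocks.

import dataclasses
from typing import Optional

import flax.linen as nn


@dataclasses.dataclass(frozen=True)
class Cfg:  # illustrative stand-in for TransformerConfig
  dropout_rate: float = 0.1
  deterministic: bool = False


class Block(nn.Module):
  config: Cfg

  @nn.compact
  def __call__(self, x, dropout_rate: Optional[float] = None):
    cfg = self.config
    # Resolve the rate once, at the top of __call__, from the config default.
    if dropout_rate is None:
      dropout_rate = cfg.dropout_rate
    x = nn.Dense(16)(x)
    # Stock nn.Dropout only takes the rate at construction; the forked Dropout
    # in algoperf/jax_utils.py additionally accepts rate= at call time.
    return nn.Dropout(rate=dropout_rate,
                      deterministic=cfg.deterministic)(x)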
diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py
index 54a917a09..3947a1b81 100644
--- a/algoperf/workloads/wmt/wmt_jax/models.py
+++ b/algoperf/workloads/wmt/wmt_jax/models.py
@@ -28,10 +28,7 @@ class TransformerConfig:
max_len: int = 256
activation: Callable = nn.relu
glu: bool = False
- #If None, defaults to 0.1.
dropout_rate: Optional[float] = 0.1
- #If None, defaults to 0.1.
- attention_dropout_rate: Optional[float] = 0.1
attention_temp: float = 1.0
deterministic: bool = False
decode: bool = False
@@ -154,6 +151,9 @@ class MlpBlock(nn.Module):
def __call__(self, inputs, dropout_rate=None):
"""Applies Transformer MlpBlock module."""
cfg = self.config
+ if dropout_rate is None:
+ dropout_rate = cfg.dropout_rate
+
actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
x = nn.Dense(
cfg.mlp_dim,
@@ -172,12 +172,7 @@ def __call__(self, inputs, dropout_rate=None):
)(
inputs)
x = x * y
- if dropout_rate is None:
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = Dropout()(x, rate=dropout_rate, deterministic=cfg.deterministic)
+ x = Dropout(rate=dropout_rate)(x, rate=dropout_rate, deterministic=cfg.deterministic)
output = nn.Dense(
actual_out_dim,
dtype=cfg.dtype,
@@ -185,7 +180,7 @@ def __call__(self, inputs, dropout_rate=None):
bias_init=cfg.bias_init,
)(
x)
- output = Dropout()(
+ output = Dropout(rate=dropout_rate)(
output, rate=dropout_rate, deterministic=cfg.deterministic)
return output
@@ -211,16 +206,14 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=None):
output after transformer encoder block.
"""
cfg = self.config
+ if dropout_rate is None:
+ dropout_rate = cfg.dropout_rate
+
pre_ln = cfg.pre_ln
# Attention block.
assert inputs.ndim == 3
x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs
- if dropout_rate is None:
- if cfg.attention_dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
x = nn.MultiHeadDotProductAttention(
num_heads=cfg.num_heads,
dtype=cfg.dtype,
@@ -233,7 +226,7 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=None):
deterministic=cfg.deterministic,
)(cfg.attention_temp * x, x, mask=encoder_mask)
- x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate)
+ x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic, rate=dropout_rate)
x = x + inputs
if not pre_ln:
x = nn.LayerNorm(dtype=cfg.dtype)(x)
@@ -275,17 +268,15 @@ def __call__(
output after transformer encoder-decoder block.
"""
cfg = self.config
+ if dropout_rate is None:
+ dropout_rate = cfg.dropout_rate
+
pre_ln = cfg.pre_ln
# Decoder block.
assert targets.ndim == 3
x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets
- if dropout_rate is None:
- if cfg.attention_dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
x = nn.MultiHeadDotProductAttention(
num_heads=cfg.num_heads,
dtype=cfg.dtype,
@@ -298,11 +289,8 @@ def __call__(
deterministic=cfg.deterministic,
decode=cfg.decode,
)(cfg.attention_temp * x, x, mask=decoder_mask)
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate)
+
+ x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic, rate=dropout_rate)
x = x + targets
if not pre_ln:
x = nn.LayerNorm(dtype=cfg.dtype)(x)
@@ -321,7 +309,7 @@ def __call__(
deterministic=cfg.deterministic,
)(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask)
- y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate)
+ y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic, rate=dropout_rate)
y = y + x
if not pre_ln:
y = nn.LayerNorm(dtype=cfg.dtype)(y)
@@ -361,6 +349,9 @@ def __call__(self,
output of a transformer encoder.
"""
cfg = self.config
+ if dropout_rate is None:
+ dropout_rate = cfg.dropout_rate
+
assert inputs.ndim == 2 # (batch, len)
# Input Embedding
@@ -377,12 +368,7 @@ def __call__(self,
x = AddPositionEmbs(
config=cfg, decode=False, name="posembed_input")(
x, inputs_positions=inputs_positions)
- if dropout_rate is None:
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = Dropout()(x, deterministic=cfg.deterministic, rate=dropout_rate)
+ x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic, rate=dropout_rate)
x = x.astype(cfg.dtype)
@@ -432,6 +418,8 @@ def __call__(
output of a transformer decoder.
"""
cfg = self.config
+ if dropout_rate is None:
+ dropout_rate = cfg.dropout_rate
assert encoded.ndim == 3 # (batch, len, depth)
assert targets.ndim == 2 # (batch, len)
@@ -453,12 +441,7 @@ def __call__(
y = AddPositionEmbs(
config=cfg, decode=cfg.decode, name="posembed_output")(
y, inputs_positions=targets_positions)
- if dropout_rate is None:
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- y = Dropout()(y, deterministic=cfg.deterministic, rate=dropout_rate)
+ y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic, rate=dropout_rate)
y = y.astype(cfg.dtype)
@@ -549,7 +532,8 @@ def decode(
targets,
targets_positions=None,
inputs_segmentation=None,
- targets_segmentation=None):
+ targets_segmentation=None,
+ dropout_rate=None):
"""Applies Transformer decoder-branch on encoded-input and target.
Args:
@@ -598,7 +582,8 @@ def decode(
targets,
targets_positions=targets_positions,
decoder_mask=decoder_mask,
- encoder_decoder_mask=encoder_decoder_mask)
+ encoder_decoder_mask=encoder_decoder_mask,
+ dropout_rate=dropout_rate)
return logits.astype(self.config.dtype)
def __call__(self,
diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py
index 367c062cb..193732640 100644
--- a/algoperf/workloads/wmt/wmt_jax/workload.py
+++ b/algoperf/workloads/wmt/wmt_jax/workload.py
@@ -259,7 +259,8 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ update_batch_norm: bool,
+ dropout_rate: Optional[float] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
@@ -286,7 +287,8 @@ def model_fn(
targets_positions=targets_positions,
inputs_segmentation=inputs_segmentations,
targets_segmentation=targets_segmentations,
- rngs={'dropout': rng})
+ rngs={'dropout': rng},
+ dropout_rate=dropout_rate)
return logits_batch, None
def _normalize_eval_metrics(
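The pattern this patch applies across the WMT model is: accept an optional `dropout_rate` argument in `__call__`, fall back to the config value when it is `None`, and pass the rate to the custom `Dropout` layer at call time so that changing the rate does not force a recompile. A minimal sketch of that pattern, assuming the `Dropout` layer exported from `algoperf.jax_utils`; the `Block` module, the `DEFAULT_DROPOUT` constant, and the shapes are illustrative:
```
# Minimal sketch of the call-time dropout-rate pattern used in this patch.
# Assumes the custom Dropout layer from algoperf.jax_utils, which accepts
# `rate` both at construction and at call time.
import jax
import jax.numpy as jnp
import flax.linen as nn

from algoperf.jax_utils import Dropout

DEFAULT_DROPOUT = 0.1  # stand-in for cfg.dropout_rate


class Block(nn.Module):

  @nn.compact
  def __call__(self, x, train=True, dropout_rate=None):
    if dropout_rate is None:
      dropout_rate = DEFAULT_DROPOUT  # fall back to the config default
    x = nn.Dense(16)(x)
    # The rate is threaded through the call, so a different rate does not
    # change the module structure or trigger a new jit trace.
    return Dropout(rate=dropout_rate)(
        x, rate=dropout_rate, deterministic=not train)


x = jnp.ones((2, 8))
variables = Block().init(jax.random.PRNGKey(0), x, train=False)
out = Block().apply(
    variables, x, train=True, dropout_rate=0.2,
    rngs={'dropout': jax.random.PRNGKey(1)})
```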
From 2c96b884eba3cd8500c8d3fd1de6feb28194fbe3 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 4 Jun 2025 20:49:15 +0000
Subject: [PATCH 016/123] modify dockerfile
---
docker/Dockerfile | 18 ++++++++----------
1 file changed, 8 insertions(+), 10 deletions(-)
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 76bc5cfe0..72e3a810f 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -5,7 +5,7 @@
# docker build -t --build-arg framework=pytorch
# To build Docker image
-FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04
+FROM nvidia/cuda:12.9.0-cudnn-devel-ubuntu20.04
# Installing machine packages
RUN echo "Setting up machine"
@@ -23,8 +23,8 @@ RUN apt-get update && apt-get install -y \
libreadline-dev \
libffi-dev \
curl \
- libbz2-dev \
liblzma-dev \
+ libbz2-dev \
vim
# Download and install Python 3.11
@@ -56,8 +56,6 @@ RUN echo "Setting up directories for data and experiment_runs"
RUN mkdir -p data/
RUN mkdir -p experiment_runs/
-RUN pip install --upgrade pip
-
# Install Algorithmic efficiency repo
RUN pip install --upgrade pip
@@ -71,18 +69,18 @@ RUN cd /algorithmic-efficiency && git checkout $branch
RUN if [ "$framework" = "jax" ] ; then \
echo "Installing Jax GPU" \
&& cd /algorithmic-efficiency \
- && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \
- && pip install -e '.[pytorch_cpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \
+ && pip install -e '.[pytorch_cpu]' -f https://download.pytorch.org/whl/cpu \
+ && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \
elif [ "$framework" = "pytorch" ] ; then \
echo "Installing Pytorch GPU" \
&& cd /algorithmic-efficiency \
- && pip install -e '.[jax_cpu]' \
- && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \
+ && pip install -e '.[pytorch_gpu]' \
+ && pip install -e '.[jax_cpu]'; \
elif [ "$framework" = "both" ] ; then \
echo "Installing Jax GPU and Pytorch GPU" \
&& cd /algorithmic-efficiency \
- && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \
- && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/cu121'; \
+ && pip install -e '.[pytorch_gpu]' \
+ && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \
else \
echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \
&& exit 1 ; \
From 54786a6594bf70388e8791aab2b35c78b7cbf028 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 4 Jun 2025 20:49:52 +0000
Subject: [PATCH 017/123] modify docker build script
---
docker/build_docker_images.sh | 23 ++++++++++++++++++-----
1 file changed, 18 insertions(+), 5 deletions(-)
diff --git a/docker/build_docker_images.sh b/docker/build_docker_images.sh
index 645b81955..6b5e67ceb 100644
--- a/docker/build_docker_images.sh
+++ b/docker/build_docker_images.sh
@@ -1,27 +1,40 @@
#!/bin/bash
# Bash script to build and push dev docker images to artifact repo
# Usage:
-# bash build_docker_images.sh -b
+# bash build_docker_images.sh -b -f
# Make program exit with non-zero exit code if any command fails.
set -e
-while getopts b: flag
+while getopts "b:p:f:" flag;
do
case "${flag}" in
b) GIT_BRANCH=${OPTARG};;
+ p) PROJECT=${OPTARG};;
+ f) FRAMEWORK=${OPTARG};;
esac
done
# Artifact repository
-ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo"
+if [ "$PROJECT" = "mlcommons-algoperf" ]; then
+ ARTIFACT_REPO="europe-west-4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo"
+else
+ ARTIFACT_REPO="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo"
+fi
-if [[ -z ${GIT_BRANCH+x} ]]
+if [[ -z ${GIT_BRANCH+x} ]];
then
GIT_BRANCH='main' # Set default argument
fi
-for FRAMEWORK in "jax" "pytorch" "both"
+FRAMEWORKS=( "jax" "pytorch" "both" )
+
+if [[ -n "$FRAMEWORK" ]];
+then
+ FRAMEWORKS=("$FRAMEWORK")
+fi
+
+for FRAMEWORK in "${FRAMEWORKS[@]}";
do
IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}"
DOCKER_BUILD_COMMAND="docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH"
From 9c189adc9b11948ab8e605f75da17187ccf35f3a Mon Sep 17 00:00:00 2001
From: David Tweedle
Date: Thu, 5 Jun 2025 16:07:06 -0400
Subject: [PATCH 018/123] Update metrics.py - fix for ogbg pytorch
It seems that the problem affecting the pytorch ogbg workloads (but only if they run for some length of time) has to do with jax/xla cpu compilation of the metrics computation. By converting the jax arrays to numpy, hopefully this can be avoided. The next step is to test on schedule free and shampoo, which I hope to do very soon.
---
algoperf/workloads/ogbg/metrics.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py
index 55f83d905..c2db383b5 100644
--- a/algoperf/workloads/ogbg/metrics.py
+++ b/algoperf/workloads/ogbg/metrics.py
@@ -40,7 +40,7 @@ def compute(self):
if USE_PYTORCH_DDP:
# Sync labels, logits, and masks across devices.
- all_values = [labels, logits, mask]
+ all_values = [np.array(labels), np.array(logits), np.array(mask)]
for idx, array in enumerate(all_values):
tensor = torch.as_tensor(array, device=DEVICE)
# Assumes that the tensors on all devices have the same shape.
@@ -51,7 +51,7 @@ def compute(self):
mask = mask.astype(bool)
- probs = jax.nn.sigmoid(logits)
+ probs = 1 / (1 + np.exp(-logits))
num_tasks = labels.shape[1]
average_precisions = np.full(num_tasks, np.nan)
From 246d68ee96d05d9b4b463dd08ae36e1715f6b3bb Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 5 Jun 2025 23:38:41 +0000
Subject: [PATCH 019/123] small fixes
---
algoperf/workloads/fastmri/fastmri_jax/models.py | 2 +-
algoperf/workloads/imagenet_vit/imagenet_jax/models.py | 2 +-
.../workloads/librispeech_conformer/librispeech_jax/models.py | 2 +-
3 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py
index 7ecca2add..b04510297 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/models.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/models.py
@@ -139,9 +139,9 @@ class ConvBlock(nn.Module):
dropout_rate: Dropout probability.
"""
out_channels: int
- dropout_rate: float = 0.0
use_tanh: bool
use_layer_norm: bool
+ dropout_rate: float = 0.0
@nn.compact
def __call__(self, x, train=True, dropout_rate=None):
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
index 227f7c297..8ffc0b610 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
@@ -162,7 +162,7 @@ class MAPHead(nn.Module):
"""Multihead Attention Pooling."""
mlp_dim: Optional[int] = None # Defaults to 4x input dim
num_heads: int = 12
- dropout_rate: 0.0
+ dropout_rate: float = 0.0
@nn.compact
def __call__(self, x, dropout_rate=None):
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
index 2ca0fffdc..2d0da15e5 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
@@ -196,7 +196,7 @@ class FeedForwardModule(nn.Module):
config: ConformerConfig
@nn.compact
- def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=dropout_rate):
+ def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=None):
config = self.config
if dropout_rate is None:
dropout_rate = config.dropout_rate
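The imagenet_vit `MAPHead` fix above is easy to miss: `dropout_rate: 0.0` is a bare annotation whose "type" is the literal `0.0`, so the dataclass machinery behind `nn.Module` sees a field with no default, whereas `dropout_rate: float = 0.0` declares a typed field that actually defaults to `0.0`. A plain-dataclass illustration of the underlying issue (the `Broken`/`Fixed` classes are illustrative):
```
# Why `dropout_rate: 0.0` differs from `dropout_rate: float = 0.0`.
import dataclasses


@dataclasses.dataclass
class Broken:
  dropout_rate: 0.0  # annotation only; the field gets no default value


@dataclasses.dataclass
class Fixed:
  dropout_rate: float = 0.0  # typed field with a default


Fixed()  # works
# Broken()  # TypeError: missing required argument 'dropout_rate'
```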
From 0c8dd14d617ff1a642915f34ccd1e504d5a8c0a1 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Fri, 6 Jun 2025 00:39:59 +0000
Subject: [PATCH 020/123] change docker base image to 12.1.1
---
docker/Dockerfile | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 72e3a810f..f1fc99550 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -5,7 +5,7 @@
# docker build -t --build-arg framework=pytorch
# To build Docker image
-FROM nvidia/cuda:12.9.0-cudnn-devel-ubuntu20.04
+FROM nvidia/cuda:12.1.1-cudnn-devel-ubuntu20.04
# Installing machine packages
RUN echo "Setting up machine"
From a78fa6642aed752010e76e953d11ee4c54bafddd Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Fri, 6 Jun 2025 00:49:04 +0000
Subject: [PATCH 021/123] update base image
---
docker/Dockerfile | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docker/Dockerfile b/docker/Dockerfile
index f1fc99550..9926b0542 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -5,7 +5,7 @@
# docker build -t --build-arg framework=pytorch
# To build Docker image
-FROM nvidia/cuda:12.1.1-cudnn-devel-ubuntu20.04
+FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04
# Installing machine packages
RUN echo "Setting up machine"
From 4beef49bad7570c40ad1563e9ffc5ab83ba8df90 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Sat, 7 Jun 2025 04:04:21 +0000
Subject: [PATCH 022/123] add slurm instructions
---
scoring/utils/slurm/README.md | 16 +++
.../utils/slurm/algoperf_slurm_cluster.yaml | 105 ++++++++++++++
.../slurm/algoperf_slurm_pakcer_builder.yaml | 132 ++++++++++++++++++
scoring/utils/slurm/config.json | 106 ++++++++++++++
4 files changed, 359 insertions(+)
create mode 100644 scoring/utils/slurm/algoperf_slurm_cluster.yaml
create mode 100644 scoring/utils/slurm/algoperf_slurm_pakcer_builder.yaml
create mode 100644 scoring/utils/slurm/config.json
diff --git a/scoring/utils/slurm/README.md b/scoring/utils/slurm/README.md
index ffd56fbf3..29df653e4 100644
--- a/scoring/utils/slurm/README.md
+++ b/scoring/utils/slurm/README.md
@@ -1,3 +1,4 @@
+# Launching SLURM jobs with SBATCH
This folder contains a SLURM batch script that can be used to run jobs where each job corresponds to a training run on a given workload, training algorithm, random seed and tuning trial (if on external tuning ruleset).
To launch jobs:
@@ -24,3 +25,18 @@ python3 make_job_config.py \
```
sbatch run_jobs.sh
```
+
+
+# Set up new SLURM cluster
+If you are setting up a new cluster, we recommend using the [HPC toolkit to set up a SLURM cluster](https://cloud.google.com/cluster-toolkit/docs/quickstarts/slurm-cluster).
+To set up the new cluster:
+
+1) [Install the Google Cluster Toolkit](https://github.com/GoogleCloudPlatform/cluster-toolkit?tab=readme-ov-file#quickstart).
+2) Create and deploy a packer node to create a base image for the cluster nodes. See [packer builder terraform blueprint](/scoring/utils/slurm/algoperf_slurm_packer_builder.yaml).
+3) Manually update the image:
+ 1) Create a VM from the Disk image created in the previous step.
+ 2) Install the NVIDIA container toolkit on the VM.
+ 3) Transfer the data from GCP bucket to `/opt/data`.
+ 4) Create a new disk image from the VM.
+4) Create and deploy the cluster. See [cluster terraform blueprint](/scoring/utils/slurm/algoperf_slurm_cluster.yaml).
+
diff --git a/scoring/utils/slurm/algoperf_slurm_cluster.yaml b/scoring/utils/slurm/algoperf_slurm_cluster.yaml
new file mode 100644
index 000000000..e6c35e017
--- /dev/null
+++ b/scoring/utils/slurm/algoperf_slurm_cluster.yaml
@@ -0,0 +1,105 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+blueprint_name: algoperf-slurm-internal
+
+vars:
+ project_id: training-algorithms-external
+ deployment_name: algoperf-slurm-internal
+ region: europe-west4
+ zone: europe-west4-a
+ disk_size_gb: 3000
+ slurm_cluster_name: algoperf
+ image_name: algoperf-image-data-container-tkt
+
+# Recommended to use GCS backend for Terraform state
+# See https://github.com/GoogleCloudPlatform/hpc-toolkit/tree/main/examples#optional-setting-up-a-remote-terraform-state
+#
+# terraform_backend_defaults:
+# type: gcs
+# configuration:
+# bucket: <>
+
+deployment_groups:
+- group: primary
+ modules:
+ - id: network
+ source: modules/network/vpc
+
+ - id: homefs
+ source: community/modules/file-system/nfs-server
+ use: [network]
+ settings:
+ local_mounts: [/home]
+ disk_size: 3000
+ zone: $(vars.zone)
+
+ - id: script
+ source: modules/scripts/startup-script
+ settings:
+
+- group: cluster
+ modules:
+ - id: v100_nodeset
+ source: community/modules/compute/schedmd-slurm-gcp-v6-nodeset
+ use:
+ - network
+ settings:
+ node_count_dynamic_max: 25 # set to 0 if you want node to live forever
+ region: $(vars.region)
+ zone: $(vars.zone)
+ enable_placement: false
+ bandwidth_tier: gvnic_enabled
+ machine_type: n1-standard-64
+ guest_accelerator:
+ - type: nvidia-tesla-v100
+ count: 8
+ instance_image:
+ project: $(vars.project_id)
+ name: $(vars.image_name)
+ instance_image_custom: true
+
+ - id: v100_partition
+ source: community/modules/compute/schedmd-slurm-gcp-v6-partition
+ use: [v100_nodeset]
+ settings:
+ exclusive: false
+ partition_name: v100
+ is_default: true
+
+ - id: slurm_login
+ source: community/modules/scheduler/schedmd-slurm-gcp-v6-login
+ use: [network]
+ settings:
+ enable_login_public_ips: true
+ instance_image:
+ project: $(vars.project_id)
+ name: $(vars.image_name)
+ instance_image_custom: true
+ zone: $(vars.zone)
+
+ - id: slurm_controller
+ source: community/modules/scheduler/schedmd-slurm-gcp-v6-controller
+ use:
+ - network
+ - v100_partition
+ - homefs
+ - slurm_login
+ settings:
+ enable_controller_public_ips: true
+ instance_image:
+ project: $(vars.project_id)
+ name: $(vars.image_name)
+ instance_image_custom: true
+ region: $(vars.region)
diff --git a/scoring/utils/slurm/algoperf_slurm_pakcer_builder.yaml b/scoring/utils/slurm/algoperf_slurm_pakcer_builder.yaml
new file mode 100644
index 000000000..286728e1d
--- /dev/null
+++ b/scoring/utils/slurm/algoperf_slurm_pakcer_builder.yaml
@@ -0,0 +1,132 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+---
+
+blueprint_name: algoperf-slurm-packer
+
+vars:
+ project_id: training-algorithms-external
+ deployment_name: algoperf-slurm-packer
+ region: europe-west4
+ zone: europe-west4-a
+ new_image:
+ family: algoperf-image
+ project: $(vars.project_id)
+ disk_size_gb: 3000
+ slurm_cluster_name: algoperf-packer
+
+# Recommended to use GCS backend for Terraform state
+# See https://github.com/GoogleCloudPlatform/hpc-toolkit/tree/main/examples#optional-setting-up-a-remote-terraform-state
+#
+# terraform_backend_defaults:
+# type: gcs
+# configuration:
+# bucket: <>
+
+deployment_groups:
+- group: primary
+ modules:
+ - id: network
+ source: modules/network/vpc
+
+ - id: script
+ source: modules/scripts/startup-script
+ settings:
+ region: $(vars.region)
+ install_ansible: true
+ docker:
+ enabled: true
+ world_writable: true
+ # (TODO) Do I need this?
+ configure_ssh_host_patterns:
+ - 10.0.0.*
+ - 10.1.0.*
+ - 10.2.0.*
+ - 10.3.0.*
+ - 10.4.0.*
+ - 10.5.0.*
+ - 10.6.0.*
+ - 10.7.0.*
+ - $(vars.slurm_cluster_name)*
+ runners:
+ - type: shell
+ destination: install-ml-libraries.sh
+ content: |
+ #!/bin/bash
+ # this script is designed to execute on Slurm images published by SchedMD that:
+ # - are based on Debian distribution of Linux
+ # - have NVIDIA drivers pre-installed
+
+ set -e -o pipefail
+
+ echo "deb https://packages.cloud.google.com/apt google-fast-socket main" > /etc/apt/sources.list.d/google-fast-socket.list
+ apt-get update --allow-releaseinfo-change
+ apt-get install --assume-yes google-fast-socket
+
+ CONDA_BASE=/opt/conda
+
+ if [ -d $CONDA_BASE ]; then
+ exit 0
+ fi
+
+ DL_DIR=\$(mktemp -d)
+ cd $DL_DIR
+ curl -L -O https://github.com/conda-forge/miniforge/releases/download/24.7.1-2/Miniforge3-24.7.1-2-Linux-x86_64.sh
+ HOME=$DL_DIR bash Miniforge3-24.7.1-2-Linux-x86_64.sh -b -p $CONDA_BASE
+ cd -
+ rm -rf $DL_DIR
+ unset DL_DIR
+
+ source $CONDA_BASE/bin/activate base
+ conda init --system
+ conda config --system --set auto_activate_base False
+ # following channel ordering is important! use strict_priority!
+ conda config --system --set channel_priority strict
+ conda update -n base conda --yes
+
+ ### create a virtual environment for tensorflow
+ conda create -n tf python=3.11 --yes
+ conda activate tf
+ pip install tensorflow[and-cuda]==2.18.*
+ pip install tensorrt==10.6.*
+
+ ### create a virtual environment for pytorch
+ conda create -n pytorch python=3.11 --yes
+ conda activate pytorch
+ pip install torch torchvision torchaudio
+
+- group: packer
+ modules:
+ - id: custom-image
+ source: modules/packer/custom-image
+ kind: packer
+ use:
+ - network
+ - script
+ settings:
+ # give VM a public IP to ensure startup script can reach public internet
+ # w/o new VPC
+ omit_external_ip: false
+ source_image_project_id: [schedmd-slurm-public]
+ # see latest in https://github.com/GoogleCloudPlatform/slurm-gcp/blob/master/docs/images.md#published-image-family
+ source_image_family: slurm-gcp-6-8-debian-11
+ # You can find size of source image by using following command
+ # gcloud compute images describe-from-family --project schedmd-slurm-public
+ disk_size: $(vars.disk_size_gb)
+ image_family: $(vars.new_image.family)
+ # building this image does not require a GPU-enabled VM
+ machine_type: c2-standard-16
+ state_timeout: 300m
+ zone: $(vars.zone)
diff --git a/scoring/utils/slurm/config.json b/scoring/utils/slurm/config.json
new file mode 100644
index 000000000..dc19e57f7
--- /dev/null
+++ b/scoring/utils/slurm/config.json
@@ -0,0 +1,106 @@
+{
+ "0": {
+ "framework": "jax",
+ "workload": "imagenet_resnet",
+ "dataset": "imagenet",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": 411096763,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "1": {
+ "framework": "jax",
+ "workload": "imagenet_vit",
+ "dataset": "imagenet",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": -1884713130,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "2": {
+ "framework": "jax",
+ "workload": "fastmri",
+ "dataset": "fastmri",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": -214785144,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "3": {
+ "framework": "jax",
+ "workload": "ogbg",
+ "dataset": "ogbg",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": -893097833,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "4": {
+ "framework": "jax",
+ "workload": "wmt",
+ "dataset": "wmt",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": -1244182279,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "5": {
+ "framework": "jax",
+ "workload": "librispeech_deepspeech",
+ "dataset": "librispeech",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": 1546003634,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "6": {
+ "framework": "jax",
+ "workload": "criteo1tb",
+ "dataset": "criteo1tb",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": -2062333143,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "7": {
+ "framework": "jax",
+ "workload": "librispeech_conformer",
+ "dataset": "librispeech",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": 409209730,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ }
+}
\ No newline at end of file
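Each top-level key in this `config.json` is a SLURM array index, so a task in the batch job can select its own entry via `SLURM_ARRAY_TASK_ID`. A hypothetical sketch of that lookup (the file path and the printed fields are illustrative):
```
# Hypothetical lookup of a task's entry in config.json by array index.
import json
import os

with open('config.json') as f:
  jobs = json.load(f)

task_id = os.environ.get('SLURM_ARRAY_TASK_ID', '0')
task = jobs[task_id]
print(task['workload'], task['framework'], task['rng_seed'])
```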
From 2b782dda88beef216629832e2511fb636ad9f974 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Sat, 7 Jun 2025 04:06:24 +0000
Subject: [PATCH 023/123] formatting docs
---
scoring/utils/slurm/README.md | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/scoring/utils/slurm/README.md b/scoring/utils/slurm/README.md
index 29df653e4..3e5d85b73 100644
--- a/scoring/utils/slurm/README.md
+++ b/scoring/utils/slurm/README.md
@@ -34,9 +34,9 @@ To set up the new cluster:
1) [Install the Google Cluster Toolkit](https://github.com/GoogleCloudPlatform/cluster-toolkit?tab=readme-ov-file#quickstart).
2) Create and deploy a packer node to create a base image for the cluster nodes. See [packer builder terraform blueprint](/scoring/utils/slurm/algoperf_slurm_packer_builder.yaml).
3) Manually update the image:
- 1) Create a VM from the Disk image created in the previous step.
- 2) Install the NVIDIA container toolkit on the VM.
- 3) Transfer the data from GCP bucket to `/opt/data`.
- 4) Create a new disk image from the VM.
+ 1) Create a VM from the Disk image created in the previous step.
+ 2) Install the NVIDIA container toolkit on the VM.
+ 3) Transfer the data from GCP bucket to `/opt/data`.
+ 4) Create a new disk image from the VM.
4) Create and deploy the cluster. See [cluster terraform blueprint](/scoring/utils/slurm/algoperf_slurm_cluster.yaml).
From dfe4fb47055cfc019dc11b22d8ffa86720d39af7 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Sat, 7 Jun 2025 04:09:41 +0000
Subject: [PATCH 024/123] update instructions
---
scoring/utils/slurm/README.md | 12 +++++++++++-
1 file changed, 11 insertions(+), 1 deletion(-)
diff --git a/scoring/utils/slurm/README.md b/scoring/utils/slurm/README.md
index 3e5d85b73..ed42752dd 100644
--- a/scoring/utils/slurm/README.md
+++ b/scoring/utils/slurm/README.md
@@ -11,7 +11,7 @@ python3 make_job_config.py \
--framework
```
2) Save the config.json in the same directory you will run the sbatch script from.
-3) Check the sbatch script `run_jobs.sh`.
+3) Copy the example sbatch script `run_jobs.sh`.
- Set the task range to the number of tasks in the config.
```
#SBATCH --array=0-119
@@ -21,6 +21,16 @@ python3 make_job_config.py \
#SBATCH --output=experiments///job_%A_%a.out
#SBATCH --error=experiments///job_%A_%a.err
```
+- Update the GCP project information, Docker image, config file path, and the logs bucket as necessary:
+```
+REPO="us-central1-docker.pkg.dev"
+IMAGE="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_main"
+y | gcloud auth configure-docker $REPO
+docker pull $IMAGE
+# Job config (ATTENTION: you may want to modify this)
+config_file="$HOME/configs/pmap_job_config.json" # Replace with your config file path
+LOGS_BUCKET="algoperf-runs-internal"
+```
4) Submit a SLURM batch job by running:
```
sbatch run_jobs.sh
From f0019ac4ad04fb8bfe2c6430475a7043dec77ff3 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Sat, 7 Jun 2025 04:23:18 +0000
Subject: [PATCH 025/123] small fix
---
.../workloads/librispeech_deepspeech/librispeech_jax/models.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
index 3ad31b532..455366e5e 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
@@ -115,7 +115,7 @@ def __call__(self, inputs, output_paddings, train, dropout_rate=None):
else:
input_dropout_rate = config.input_dropout_rate
outputs = Dropout(
- rate=input_dropout_rate, deterministic=not train, rate=dropout_rate)(
+ rate=input_dropout_rate, deterministic=not train)(
outputs, rate=dropout_rate)
return outputs, output_paddings
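The "small fix" above removes a duplicated `rate=` keyword in the `Dropout` call; Python rejects repeated keyword arguments at parse time, so the previous form could not even be imported. A minimal illustration with a hypothetical function:
```
# Duplicate keyword arguments are a SyntaxError, caught at parse time.
def f(rate=None, deterministic=None):
  return rate, deterministic


# f(rate=0.1, deterministic=True, rate=0.2)
# SyntaxError: keyword argument repeated: rate
f(rate=0.1, deterministic=True)  # fixed form: each keyword appears once
```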
From 3cb012e919374e204f123145ebbe596bf72b4eac Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Sat, 7 Jun 2025 04:34:36 +0000
Subject: [PATCH 026/123] remove aux_dropout from submission_runner.py
---
submission_runner.py | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/submission_runner.py b/submission_runner.py
index 468a04c7c..bb4a8c6cc 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -229,13 +229,10 @@ def train_once(
logging.info('Initializing model.')
with profiler.profile('Initializing model'):
dropout_rate = None
- aux_dropout_rate = None
if hasattr(hyperparameters, 'dropout_rate'):
dropout_rate = hyperparameters.dropout_rate
- if hasattr(hyperparameters, 'aux_dropout_rate'):
- aux_dropout_rate = hyperparameters.aux_dropout_rate
model_params, model_state = workload.init_model_fn(
- model_init_rng, dropout_rate, aux_dropout_rate)
+ model_init_rng, dropout_rate)
if FLAGS.framework == 'pytorch' and FLAGS.torch_compile:
compile_error_workloads = [
'librispeech_conformer',
From fdc956bf62b80f84e634d6717ab9bc12aea33a9e Mon Sep 17 00:00:00 2001
From: David Tweedle
Date: Mon, 9 Jun 2025 10:27:45 -0400
Subject: [PATCH 027/123] Update metrics.py
The problem with torchrun and jax seems to be caused by jax.nn.sigmoid.
---
algoperf/workloads/ogbg/metrics.py | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py
index c2db383b5..982e1044e 100644
--- a/algoperf/workloads/ogbg/metrics.py
+++ b/algoperf/workloads/ogbg/metrics.py
@@ -37,10 +37,11 @@ def compute(self):
labels = values['labels']
logits = values['logits']
mask = values['mask']
+ sigmoid = jax.nn.sigmoid
if USE_PYTORCH_DDP:
# Sync labels, logits, and masks across devices.
- all_values = [np.array(labels), np.array(logits), np.array(mask)]
+ all_values = [labels, logits, mask]
for idx, array in enumerate(all_values):
tensor = torch.as_tensor(array, device=DEVICE)
# Assumes that the tensors on all devices have the same shape.
@@ -48,10 +49,11 @@ def compute(self):
dist.all_gather(all_tensors, tensor)
all_values[idx] = torch.cat(all_tensors).cpu().numpy()
labels, logits, mask = all_values
+ sigmoid = lambda x: 1 / (1 + np.exp(-x))
mask = mask.astype(bool)
- probs = 1 / (1 + np.exp(-logits))
+ probs = sigmoid(logits)
num_tasks = labels.shape[1]
average_precisions = np.full(num_tasks, np.nan)
From 6c888df9a365be98332575bae74295b23501a7ea Mon Sep 17 00:00:00 2001
From: David Tweedle
Date: Mon, 9 Jun 2025 10:43:29 -0400
Subject: [PATCH 028/123] Update metrics.py
Changed from a lambda expression, which pylint doesn't like.
---
algoperf/workloads/ogbg/metrics.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py
index 982e1044e..5e1e7c4ef 100644
--- a/algoperf/workloads/ogbg/metrics.py
+++ b/algoperf/workloads/ogbg/metrics.py
@@ -31,6 +31,9 @@ class MeanAveragePrecision(
metrics.CollectingMetric.from_outputs(('logits', 'labels', 'mask'))):
"""Computes the mean average precision (mAP) over different tasks."""
+ def sigmoid_np(x):
+ return 1 / (1 + np.exp(-x))
+
def compute(self):
# Matches the official OGB evaluation scheme for mean average precision.
values = super().compute()
@@ -49,7 +52,7 @@ def compute(self):
dist.all_gather(all_tensors, tensor)
all_values[idx] = torch.cat(all_tensors).cpu().numpy()
labels, logits, mask = all_values
- sigmoid = lambda x: 1 / (1 + np.exp(-x))
+ sigmoid = sigmoid_np
mask = mask.astype(bool)
From e4a55ab1db0a114a3a713c327ab334effcba9d53 Mon Sep 17 00:00:00 2001
From: David Tweedle
Date: Mon, 9 Jun 2025 17:19:11 -0400
Subject: [PATCH 029/123] Update metrics.py
Defined np sigmoid inside use_pytorch_ddp
---
algoperf/workloads/ogbg/metrics.py | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py
index 5e1e7c4ef..8f342c25d 100644
--- a/algoperf/workloads/ogbg/metrics.py
+++ b/algoperf/workloads/ogbg/metrics.py
@@ -31,9 +31,6 @@ class MeanAveragePrecision(
metrics.CollectingMetric.from_outputs(('logits', 'labels', 'mask'))):
"""Computes the mean average precision (mAP) over different tasks."""
- def sigmoid_np(x):
- return 1 / (1 + np.exp(-x))
-
def compute(self):
# Matches the official OGB evaluation scheme for mean average precision.
values = super().compute()
@@ -52,6 +49,8 @@ def compute(self):
dist.all_gather(all_tensors, tensor)
all_values[idx] = torch.cat(all_tensors).cpu().numpy()
labels, logits, mask = all_values
+ def sigmoid_np(x):
+ return 1 / (1 + np.exp(-x))
sigmoid = sigmoid_np
mask = mask.astype(bool)
From 07f89a2b69d6f7667c83961e0fea4ae228681355 Mon Sep 17 00:00:00 2001
From: David Tweedle
Date: Mon, 9 Jun 2025 17:29:58 -0400
Subject: [PATCH 030/123] Update metrics.py
Added white space before and after sigmoid_np
---
algoperf/workloads/ogbg/metrics.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py
index 8f342c25d..19d43aae4 100644
--- a/algoperf/workloads/ogbg/metrics.py
+++ b/algoperf/workloads/ogbg/metrics.py
@@ -49,8 +49,10 @@ def compute(self):
dist.all_gather(all_tensors, tensor)
all_values[idx] = torch.cat(all_tensors).cpu().numpy()
labels, logits, mask = all_values
+
def sigmoid_np(x):
return 1 / (1 + np.exp(-x))
+
sigmoid = sigmoid_np
mask = mask.astype(bool)
From 3e436c771f270be867a6ec973eb6c022a5c6ae69 Mon Sep 17 00:00:00 2001
From: David Tweedle
Date: Mon, 9 Jun 2025 18:19:52 -0400
Subject: [PATCH 031/123] Update metrics.py
Fix white space
---
algoperf/workloads/ogbg/metrics.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py
index 19d43aae4..ea6041a6c 100644
--- a/algoperf/workloads/ogbg/metrics.py
+++ b/algoperf/workloads/ogbg/metrics.py
@@ -49,10 +49,10 @@ def compute(self):
dist.all_gather(all_tensors, tensor)
all_values[idx] = torch.cat(all_tensors).cpu().numpy()
labels, logits, mask = all_values
-
+
def sigmoid_np(x):
return 1 / (1 + np.exp(-x))
-
+
sigmoid = sigmoid_np
mask = mask.astype(bool)
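Taken together, these metrics patches settle on the following shape: use `jax.nn.sigmoid` by default, but switch to a numpy sigmoid on the DDP path, where the gathered values are already numpy arrays and calling into jax would trigger the XLA CPU compilation behind the reported ogbg issue. A condensed sketch of that selection (the `pick_sigmoid` helper and `use_ddp` flag are illustrative stand-ins for the workload's `USE_PYTORCH_DDP` branch):
```
# Illustrative condensation of the final sigmoid handling in compute().
import numpy as np
import jax


def pick_sigmoid(use_ddp):
  sigmoid = jax.nn.sigmoid
  if use_ddp:
    # Gathered logits are numpy arrays here; a numpy sigmoid avoids
    # invoking jax/XLA CPU compilation inside the metrics computation.
    def sigmoid_np(x):
      return 1 / (1 + np.exp(-x))

    sigmoid = sigmoid_np
  return sigmoid


probs = pick_sigmoid(use_ddp=True)(np.zeros((4, 2)))  # all entries 0.5
```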
From b3060762eda8a47c6409bf5804ccca8d53686e0c Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Tue, 10 Jun 2025 18:42:28 +0200
Subject: [PATCH 032/123] dropout fix criteo, fastmri, vit, conf
---
.../criteo1tb_pytorch/models_dropout.py | 298 ++++++++++
.../models_functional_dropout.py | 308 +++++++++++
algoperf/workloads/dropout_modules.py | 41 ++
.../fastmri/fastmri_pytorch/models_dropout.py | 167 ++++++
.../imagenet_pytorch/models_dropout.py | 395 +++++++++++++
.../librispeech_pytorch/models_dropout.py | 518 ++++++++++++++++++
6 files changed, 1727 insertions(+)
create mode 100644 algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
create mode 100644 algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py
create mode 100644 algoperf/workloads/dropout_modules.py
create mode 100644 algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
create mode 100644 algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py
create mode 100644 algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
new file mode 100644
index 000000000..8042ec31e
--- /dev/null
+++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
@@ -0,0 +1,298 @@
+"""Pytorch implementation of DLRM-Small."""
+
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from algoperf.workloads.dropout_modules import CustomDropout, SequentialWithDropout
+
+
+class DenseBlock(nn.Module):
+ """Dense block with optional residual connection."""
+ def __init__(self, module, resnet=False):
+ super().__init__()
+ self.module = module
+ self.resnet = resnet
+
+ def forward(self, x):
+ return self.module(x) + x if self.resnet else self.module(x)
+
+
+class DenseBlockWithDropout(nn.Module):
+ """Dense block with optional residual connection and support for dropout."""
+ def __init__(self, module, resnet=False):
+ super().__init__()
+ self.module = module
+ self.resnet = resnet
+ self._supports_custom_dropout = True
+
+ def forward(self, x, p=None):
+ return self.module(x, p) + x if self.resnet else self.module(x, p)
+
+
+class DotInteract(nn.Module):
+ """Performs feature interaction operation between dense or sparse features."""
+
+ def __init__(self, num_sparse_features):
+ super().__init__()
+ self.triu_indices = torch.triu_indices(num_sparse_features + 1,
+ num_sparse_features + 1)
+
+ def forward(self, dense_features, sparse_features):
+ combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features),
+ dim=1)
+ interactions = torch.bmm(combined_values,
+ torch.transpose(combined_values, 1, 2))
+ interactions_flat = interactions[:,
+ self.triu_indices[0],
+ self.triu_indices[1]]
+ return torch.cat((dense_features, interactions_flat), dim=1)
+
+
+class DLRMResNet(nn.Module):
+ """Define a DLRM-Small model.
+
+ Parameters:
+ vocab_size: vocab size of embedding table.
+ num_dense_features: number of dense features as the bottom mlp input.
+ mlp_bottom_dims: dimensions of dense layers of the bottom mlp.
+ mlp_top_dims: dimensions of dense layers of the top mlp.
+ embed_dim: embedding dimension.
+ """
+
+ def __init__(self,
+ vocab_size,
+ num_dense_features=13,
+ num_sparse_features=26,
+ mlp_bottom_dims=(256, 256, 256),
+ mlp_top_dims=(256, 256, 256, 256, 1),
+ embed_dim=128,
+ dropout_rate=0.0,
+ use_layer_norm=False,
+ embedding_init_multiplier=None):
+ super().__init__()
+ self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32)
+ self.num_dense_features = num_dense_features
+ self.num_sparse_features = num_sparse_features
+ self.mlp_bottom_dims = mlp_bottom_dims
+ self.mlp_top_dims = mlp_top_dims
+ self.embed_dim = embed_dim
+
+ # Ideally, we should use the pooled embedding implementation from
+ # `TorchRec`. However, in order to have identical implementation
+ # with that of Jax, we define a single embedding matrix.
+ num_chunks = 4
+ assert vocab_size % num_chunks == 0
+ self.embedding_table_chucks = []
+ scale = 1.0 / torch.sqrt(self.vocab_size)
+ for i in range(num_chunks):
+ chunk = nn.Parameter(
+ torch.Tensor(self.vocab_size // num_chunks, self.embed_dim))
+ chunk.data.uniform_(0, 1)
+ chunk.data = scale * chunk.data
+ self.register_parameter(f'embedding_chunk_{i}', chunk)
+ self.embedding_table_chucks.append(chunk)
+
+ input_dim = self.num_dense_features
+ bot_mlp_blocks = []
+ for layer_idx, dense_dim in enumerate(self.mlp_bottom_dims):
+ block = []
+ block.append(nn.Linear(input_dim, dense_dim))
+ block.append(nn.ReLU(inplace=True))
+ block = nn.Sequential(*block)
+ if layer_idx > 0:
+ block = DenseBlock(block, resnet=True)
+ else:
+ block = DenseBlock(block)
+ bot_mlp_blocks.append(block)
+ input_dim = dense_dim
+ self.bot_mlp = nn.Sequential(*bot_mlp_blocks)
+
+ for module in self.bot_mlp.modules():
+ if isinstance(module, nn.Linear):
+ limit = math.sqrt(6. / (module.in_features + module.out_features))
+ nn.init.uniform_(module.weight.data, -limit, limit)
+ nn.init.normal_(module.bias.data,
+ 0.,
+ math.sqrt(1. / module.out_features))
+
+ # Number of sparse features = 26
+ fan_in = (26 * self.embed_dim) + self.mlp_bottom_dims[-1]
+ num_layers_top = len(self.mlp_top_dims)
+ mlp_top_blocks = []
+ for layer_idx, fan_out in enumerate(self.mlp_top_dims):
+ block = []
+ block.append(nn.Linear(fan_in, fan_out))
+ if layer_idx < (num_layers_top - 1):
+ block.append(nn.ReLU(inplace=True))
+ if (dropout_rate is not None and dropout_rate > 0.0 and
+ layer_idx == num_layers_top - 2):
+ block.append(CustomDropout()) # (nico)
+ block = SequentialWithDropout(*block) # (nico)
+ if (layer_idx != 0) and (layer_idx != num_layers_top - 1):
+ block = DenseBlockWithDropout(block, resnet=True)
+ else:
+ block = DenseBlockWithDropout(block)
+ mlp_top_blocks.append(block)
+ fan_in = fan_out
+ self.top_mlp = SequentialWithDropout(*mlp_top_blocks) # (nico)
+
+ for module in self.top_mlp.modules():
+ if isinstance(module, nn.Linear):
+ nn.init.normal_(
+ module.weight.data,
+ 0.,
+ math.sqrt(2. / (module.in_features + module.out_features)))
+ nn.init.normal_(module.bias.data,
+ 0.,
+ math.sqrt(1. / module.out_features))
+
+ def forward(self, x, dropout_rate):
+ batch_size = x.shape[0]
+
+ dense_features, sparse_features = torch.split(
+ x, [self.num_dense_features, self.num_sparse_features], 1)
+
+ # Bottom MLP.
+ embedded_dense = self.bot_mlp(dense_features)
+
+ # Sparse feature processing.
+ sparse_features = sparse_features.to(dtype=torch.int32)
+ idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size
+ embedding_table = torch.cat(self.embedding_table_chucks, dim=0)
+ embedded_sparse = embedding_table[idx_lookup]
+ embedded_sparse = torch.reshape(embedded_sparse,
+ [batch_size, 26 * self.embed_dim])
+ top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1)
+
+ # Final MLP.
+ logits = self.top_mlp(top_mlp_input, dropout_rate)
+ return logits
+
+
+class DlrmSmall(nn.Module):
+ """Define a DLRM-Small model.
+
+ Parameters:
+ vocab_size: vocab size of embedding table.
+ num_dense_features: number of dense features as the bottom mlp input.
+ mlp_bottom_dims: dimensions of dense layers of the bottom mlp.
+ mlp_top_dims: dimensions of dense layers of the top mlp.
+ embed_dim: embedding dimension.
+ """
+
+ def __init__(self,
+ vocab_size,
+ num_dense_features=13,
+ num_sparse_features=26,
+ mlp_bottom_dims=(512, 256, 128),
+ mlp_top_dims=(1024, 1024, 512, 256, 1),
+ embed_dim=128,
+ dropout_rate=0.0,
+ use_layer_norm=False,
+ embedding_init_multiplier=None):
+ super().__init__()
+ self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32)
+ self.num_dense_features = num_dense_features
+ self.num_sparse_features = num_sparse_features
+ self.mlp_bottom_dims = mlp_bottom_dims
+ self.mlp_top_dims = mlp_top_dims
+ self.embed_dim = embed_dim
+ self.embedding_init_multiplier = embedding_init_multiplier
+
+ # Ideally, we should use the pooled embedding implementation from
+ # `TorchRec`. However, in order to have identical implementation
+ # with that of Jax, we define a single embedding matrix.
+ num_chunks = 4
+ assert vocab_size % num_chunks == 0
+ self.embedding_table_chucks = []
+
+ if self.embedding_init_multiplier is None:
+ scale = 1.0 / torch.sqrt(self.vocab_size)
+ else:
+ scale = self.embedding_init_multiplier
+
+ for i in range(num_chunks):
+ chunk = nn.Parameter(
+ torch.Tensor(self.vocab_size // num_chunks, self.embed_dim))
+ chunk.data.uniform_(0, 1)
+ chunk.data = scale * chunk.data
+ self.register_parameter(f'embedding_chunk_{i}', chunk)
+ self.embedding_table_chucks.append(chunk)
+
+ input_dim = self.num_dense_features
+ bottom_mlp_layers = []
+ for dense_dim in self.mlp_bottom_dims:
+ bottom_mlp_layers.append(nn.Linear(input_dim, dense_dim))
+ bottom_mlp_layers.append(nn.ReLU(inplace=True))
+ if use_layer_norm:
+ bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6))
+ input_dim = dense_dim
+ self.bot_mlp = nn.Sequential(*bottom_mlp_layers)
+ for module in self.bot_mlp.modules():
+ if isinstance(module, nn.Linear):
+ limit = math.sqrt(6. / (module.in_features + module.out_features))
+ nn.init.uniform_(module.weight.data, -limit, limit)
+ nn.init.normal_(module.bias.data,
+ 0.,
+ math.sqrt(1. / module.out_features))
+
+ self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,)
+
+ # TODO: Write down the formula here instead of the constant.
+ input_dims = 506
+ num_layers_top = len(self.mlp_top_dims)
+ top_mlp_layers = []
+ for layer_idx, fan_out in enumerate(self.mlp_top_dims):
+ fan_in = input_dims if layer_idx == 0 \
+ else self.mlp_top_dims[layer_idx - 1]
+ top_mlp_layers.append(nn.Linear(fan_in, fan_out))
+ if layer_idx < (num_layers_top - 1):
+ top_mlp_layers.append(nn.ReLU(inplace=True))
+ if use_layer_norm:
+ top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6))
+ if (dropout_rate is not None and dropout_rate > 0.0 and
+ layer_idx == num_layers_top - 2):
+ top_mlp_layers.append(CustomDropout()) # (nico)
+ self.top_mlp = SequentialWithDropout(*top_mlp_layers) # (nico)
+ if use_layer_norm:
+ self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6)
+ else:
+ self.embed_ln = None
+ for module in self.top_mlp.modules():
+ if isinstance(module, nn.Linear):
+ nn.init.normal_(
+ module.weight.data,
+ 0.,
+ math.sqrt(2. / (module.in_features + module.out_features)))
+ nn.init.normal_(module.bias.data,
+ 0.,
+ math.sqrt(1. / module.out_features))
+
+ def forward(self, x, dropout_rate):
+ batch_size = x.shape[0]
+
+ dense_features, sparse_features = torch.split(
+ x, [self.num_dense_features, self.num_sparse_features], 1)
+
+ # Bottom MLP.
+ embedded_dense = self.bot_mlp(dense_features)
+
+ # Sparse feature processing.
+ sparse_features = sparse_features.to(dtype=torch.int32)
+ idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size
+ embedding_table = torch.cat(self.embedding_table_chucks, dim=0)
+ embedded_sparse = embedding_table[idx_lookup]
+ embedded_sparse = torch.reshape(embedded_sparse,
+ [batch_size, -1, self.embed_dim])
+ if self.embed_ln:
+ embedded_sparse = self.embed_ln(embedded_sparse)
+ # Dot product interactions.
+ concatenated_dense = self.dot_interact(
+ dense_features=embedded_dense, sparse_features=embedded_sparse)
+
+ # Final MLP.
+ logits = self.top_mlp(concatenated_dense, dropout_rate)
+ return logits
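For the PyTorch side, the call-time rate is threaded through the containers rather than stored in the layers: `CustomDropout` wraps `torch.nn.functional.dropout`, and `SequentialWithDropout` forwards the rate only to dropout-aware submodules (both are defined in `algoperf/workloads/dropout_modules.py`, added later in this patch). A minimal sketch of that wiring; the layer sizes and rate are illustrative:
```
# Minimal sketch of call-time dropout in the PyTorch models.
# Assumes the helpers from algoperf/workloads/dropout_modules.py.
import torch
from torch import nn

from algoperf.workloads.dropout_modules import (CustomDropout,
                                                SequentialWithDropout)

block = SequentialWithDropout(
    nn.Linear(8, 8),
    nn.ReLU(inplace=True),
    CustomDropout(),  # the rate is supplied per forward call, not at init
)

x = torch.ones(2, 8)
train_out = block(x, 0.3)  # dropout applied with p=0.3 while training
block.eval()
eval_out = block(x, 0.3)   # no-op once the module is in eval mode
```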
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py
new file mode 100644
index 000000000..346e0e72a
--- /dev/null
+++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py
@@ -0,0 +1,308 @@
+"""Pytorch implementation of DLRM-Small."""
+
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class DenseBlock(nn.Module):
+ """Dense block with optional residual connection."""
+
+ def __init__(self, module, resnet=False):
+ super().__init__()
+ self.module = module
+ self.resnet = resnet
+
+ def forward(self, x):
+ if self.resnet:
+ return self.module(x) + x
+ else:
+ return self.module(x)
+
+
+class DotInteract(nn.Module):
+ """Performs feature interaction operation between dense or sparse features."""
+
+ def __init__(self, num_sparse_features):
+ super().__init__()
+ self.triu_indices = torch.triu_indices(num_sparse_features + 1,
+ num_sparse_features + 1)
+
+ def forward(self, dense_features, sparse_features):
+ combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features),
+ dim=1)
+ interactions = torch.bmm(combined_values,
+ torch.transpose(combined_values, 1, 2))
+ interactions_flat = interactions[:,
+ self.triu_indices[0],
+ self.triu_indices[1]]
+ return torch.cat((dense_features, interactions_flat), dim=1)
+
+
+class DLRMResNet(nn.Module):
+ """Define a DLRM-Small model.
+
+ Parameters:
+ vocab_size: vocab size of embedding table.
+ num_dense_features: number of dense features as the bottom mlp input.
+ mlp_bottom_dims: dimensions of dense layers of the bottom mlp.
+ mlp_top_dims: dimensions of dense layers of the top mlp.
+ embed_dim: embedding dimension.
+ """
+
+ def __init__(self,
+ vocab_size,
+ num_dense_features=13,
+ num_sparse_features=26,
+ mlp_bottom_dims=(256, 256, 256),
+ mlp_top_dims=(256, 256, 256, 256, 1),
+ embed_dim=128,
+ # dropout_rate=0.0,
+ use_layer_norm=False,
+ embedding_init_multiplier=None):
+ super().__init__()
+ self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32)
+ self.num_dense_features = num_dense_features
+ self.num_sparse_features = num_sparse_features
+ self.mlp_bottom_dims = mlp_bottom_dims
+ self.mlp_top_dims = mlp_top_dims
+ self.embed_dim = embed_dim
+
+ # Ideally, we should use the pooled embedding implementation from
+ # `TorchRec`. However, in order to have identical implementation
+ # with that of Jax, we define a single embedding matrix.
+ num_chunks = 4
+ assert vocab_size % num_chunks == 0
+ self.embedding_table_chucks = []
+ scale = 1.0 / torch.sqrt(self.vocab_size)
+ for i in range(num_chunks):
+ chunk = nn.Parameter(
+ torch.Tensor(self.vocab_size // num_chunks, self.embed_dim))
+ chunk.data.uniform_(0, 1)
+ chunk.data = scale * chunk.data
+ self.register_parameter(f'embedding_chunk_{i}', chunk)
+ self.embedding_table_chucks.append(chunk)
+
+ input_dim = self.num_dense_features
+ bot_mlp_blocks = []
+ for layer_idx, dense_dim in enumerate(self.mlp_bottom_dims):
+ block = []
+ block.append(nn.Linear(input_dim, dense_dim))
+ block.append(nn.ReLU(inplace=True))
+ block = nn.Sequential(*block)
+ if layer_idx > 0:
+ block = DenseBlock(block, resnet=True)
+ else:
+ block = DenseBlock(block)
+ bot_mlp_blocks.append(block)
+ input_dim = dense_dim
+ self.bot_mlp = nn.Sequential(*bot_mlp_blocks)
+
+ for module in self.bot_mlp.modules():
+ if isinstance(module, nn.Linear):
+ limit = math.sqrt(6. / (module.in_features + module.out_features))
+ nn.init.uniform_(module.weight.data, -limit, limit)
+ nn.init.normal_(module.bias.data,
+ 0.,
+ math.sqrt(1. / module.out_features))
+
+ # Number of sparse features = 26
+ fan_in = (26 * self.embed_dim) + self.mlp_bottom_dims[-1]
+ num_layers_top = len(self.mlp_top_dims)
+ mlp_top_blocks = []
+ for layer_idx, fan_out in enumerate(self.mlp_top_dims):
+ block = []
+ block.append(nn.Linear(fan_in, fan_out))
+ if layer_idx < (num_layers_top - 1):
+ block.append(nn.ReLU(inplace=True))
+ # if (dropout_rate is not None and dropout_rate > 0.0 and
+ # layer_idx == num_layers_top - 2):
+ # block.append(nn.Dropout(p=dropout_rate))
+ block = nn.Sequential(*block)
+ if (layer_idx != 0) and (layer_idx != num_layers_top - 1):
+ block = DenseBlock(block, resnet=True)
+ else:
+ block = DenseBlock(block)
+ mlp_top_blocks.append(block)
+ fan_in = fan_out
+ self.top_mlp = nn.Sequential(*mlp_top_blocks)
+
+ for module in self.top_mlp.modules():
+ if isinstance(module, nn.Linear):
+ nn.init.normal_(
+ module.weight.data,
+ 0.,
+ math.sqrt(2. / (module.in_features + module.out_features)))
+ nn.init.normal_(module.bias.data,
+ 0.,
+ math.sqrt(1. / module.out_features))
+
+ def forward(self, x, dropout_rate):
+ batch_size = x.shape[0]
+
+ dense_features, sparse_features = torch.split(
+ x, [self.num_dense_features, self.num_sparse_features], 1)
+
+ # Bottom MLP.
+ embedded_dense = self.bot_mlp(dense_features)
+
+ # Sparse feature processing.
+ sparse_features = sparse_features.to(dtype=torch.int32)
+ idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size
+ embedding_table = torch.cat(self.embedding_table_chucks, dim=0)
+ embedded_sparse = embedding_table[idx_lookup]
+ embedded_sparse = torch.reshape(embedded_sparse,
+ [batch_size, 26 * self.embed_dim])
+ top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1)
+
+ # Final MLP (manually unrolled so dropout can be applied functionally).
+ h = top_mlp_input
+ num_layers_top = len(self.mlp_top_dims)
+ for layer_idx, block in enumerate(self.top_mlp):
+ # block.module is nn.Sequential([...])
+ seq = block.module
+ # 1) linear
+ out = seq[0](h)
+ # 2) ReLU (if present)
+ if layer_idx < (num_layers_top - 1):
+ out = seq[1](out)
+ # 3) functional dropout at penult layer
+ if dropout_rate > 0 and layer_idx == num_layers_top - 2:
+ out = F.dropout(out, dropout_rate, training=self.training)
+ # 4) wrap in residual if needed
+ h = out + h if block.resnet else out
+ return h
+
+
+class DlrmSmall(nn.Module):
+ """Define a DLRM-Small model.
+
+ Parameters:
+ vocab_size: vocab size of embedding table.
+ num_dense_features: number of dense features as the bottom mlp input.
+ mlp_bottom_dims: dimensions of dense layers of the bottom mlp.
+ mlp_top_dims: dimensions of dense layers of the top mlp.
+ embed_dim: embedding dimension.
+ """
+
+ def __init__(self,
+ vocab_size,
+ num_dense_features=13,
+ num_sparse_features=26,
+ mlp_bottom_dims=(512, 256, 128),
+ mlp_top_dims=(1024, 1024, 512, 256, 1),
+ embed_dim=128,
+ # dropout_rate=0.0,
+ use_layer_norm=False,
+ embedding_init_multiplier=None):
+ super().__init__()
+ self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32)
+ self.num_dense_features = num_dense_features
+ self.num_sparse_features = num_sparse_features
+ self.mlp_bottom_dims = mlp_bottom_dims
+ self.mlp_top_dims = mlp_top_dims
+ self.embed_dim = embed_dim
+ self.embedding_init_multiplier = embedding_init_multiplier
+
+ # Ideally, we should use the pooled embedding implementation from
+ # `TorchRec`. However, in order to have identical implementation
+ # with that of Jax, we define a single embedding matrix.
+ num_chunks = 4
+ assert vocab_size % num_chunks == 0
+ self.embedding_table_chucks = []
+
+ if self.embedding_init_multiplier is None:
+ scale = 1.0 / torch.sqrt(self.vocab_size)
+ else:
+ scale = self.embedding_init_multiplier
+
+ for i in range(num_chunks):
+ chunk = nn.Parameter(
+ torch.Tensor(self.vocab_size // num_chunks, self.embed_dim))
+ chunk.data.uniform_(0, 1)
+ chunk.data = scale * chunk.data
+ self.register_parameter(f'embedding_chunk_{i}', chunk)
+ self.embedding_table_chucks.append(chunk)
+
+ input_dim = self.num_dense_features
+ bottom_mlp_layers = []
+ for dense_dim in self.mlp_bottom_dims:
+ bottom_mlp_layers.append(nn.Linear(input_dim, dense_dim))
+ bottom_mlp_layers.append(nn.ReLU(inplace=True))
+ if use_layer_norm:
+ bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6))
+ input_dim = dense_dim
+ self.bot_mlp = nn.Sequential(*bottom_mlp_layers)
+ for module in self.bot_mlp.modules():
+ if isinstance(module, nn.Linear):
+ limit = math.sqrt(6. / (module.in_features + module.out_features))
+ nn.init.uniform_(module.weight.data, -limit, limit)
+ nn.init.normal_(module.bias.data,
+ 0.,
+ math.sqrt(1. / module.out_features))
+
+ self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,)
+
+ # TODO: Write down the formula here instead of the constant.
+ input_dims = 506
+ num_layers_top = len(self.mlp_top_dims)
+ top_mlp_layers = []
+ for layer_idx, fan_out in enumerate(self.mlp_top_dims):
+ fan_in = input_dims if layer_idx == 0 \
+ else self.mlp_top_dims[layer_idx - 1]
+ top_mlp_layers.append(nn.Linear(fan_in, fan_out))
+ if layer_idx < (num_layers_top - 1):
+ top_mlp_layers.append(nn.ReLU(inplace=True))
+ if use_layer_norm:
+ top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6))
+ # if (dropout_rate is not None and dropout_rate > 0.0 and
+ # layer_idx == num_layers_top - 2):
+ # top_mlp_layers.append(nn.Dropout(p=dropout_rate))
+ self.top_mlp = nn.Sequential(*top_mlp_layers)
+ if use_layer_norm:
+ self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6)
+ else:
+ self.embed_ln = None
+ for module in self.top_mlp.modules():
+ if isinstance(module, nn.Linear):
+ nn.init.normal_(
+ module.weight.data,
+ 0.,
+ math.sqrt(2. / (module.in_features + module.out_features)))
+ nn.init.normal_(module.bias.data,
+ 0.,
+ math.sqrt(1. / module.out_features))
+
+ def forward(self, x, dropout_rate):
+ batch_size = x.shape[0]
+
+ dense_features, sparse_features = torch.split(
+ x, [self.num_dense_features, self.num_sparse_features], 1)
+
+ # Bottom MLP.
+ embedded_dense = self.bot_mlp(dense_features)
+
+ # Sparse feature processing.
+ sparse_features = sparse_features.to(dtype=torch.int32)
+ idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size
+ embedding_table = torch.cat(self.embedding_table_chucks, dim=0)
+ embedded_sparse = embedding_table[idx_lookup]
+ embedded_sparse = torch.reshape(embedded_sparse,
+ [batch_size, -1, self.embed_dim])
+ if self.embed_ln:
+ embedded_sparse = self.embed_ln(embedded_sparse)
+ # Dot product interactions.
+ concatenated_dense = self.dot_interact(
+ dense_features=embedded_dense, sparse_features=embedded_sparse)
+
+ # Top MLP: run each layer and apply functional dropout after the penultimate layer.
+ h = concatenated_dense
+ N = len(self.top_mlp)
+ for idx, layer in enumerate(self.top_mlp):
+ h = layer(h)
+ # insert dropout exactly where nn.Dropout used to live
+ if dropout_rate > 0 and idx == N - 2:
+ h = F.dropout(h, dropout_rate, training=self.training)
+ return h
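
For reference, the pattern used above — taking dropout_rate as a forward() argument and calling F.dropout directly — can be illustrated with a minimal standalone sketch (the TinyMLP module below is hypothetical and not part of this patch); the same instance can be driven with a different rate on every step without rebuilding any nn.Dropout modules:

import torch
import torch.nn.functional as F
from torch import nn


class TinyMLP(nn.Module):
  """Minimal sketch of the functional-dropout-in-forward pattern."""

  def __init__(self, dim=8):
    super().__init__()
    self.lin1 = nn.Linear(dim, dim)
    self.lin2 = nn.Linear(dim, 1)

  def forward(self, x, dropout_rate):
    x = F.relu(self.lin1(x))
    # The rate comes from the caller, so it can change between calls
    # without mutating the module or triggering any re-initialization.
    x = F.dropout(x, dropout_rate, training=self.training)
    return self.lin2(x)


model = TinyMLP().train()
x = torch.randn(4, 8)
y0 = model(x, dropout_rate=0.0)  # e.g. no dropout early in training
y1 = model(x, dropout_rate=0.1)  # higher rate later, same module
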
diff --git a/algoperf/workloads/dropout_modules.py b/algoperf/workloads/dropout_modules.py
new file mode 100644
index 000000000..3917b75bf
--- /dev/null
+++ b/algoperf/workloads/dropout_modules.py
@@ -0,0 +1,41 @@
+"""Custom classes to support a dynamic modulized dropout, see issue??TODO"""
+
+from torch import Tensor
+from torch import nn
+import torch.nn.functional as F
+
+
+class CustomDropout(nn.Module):
+ """A module around torch.nn.functional.dropout."""
+ def __init__(self):
+ super().__init__()
+ self._supports_custom_dropout = True
+
+ def forward(self, input: Tensor, p: float) -> Tensor:
+ return F.dropout(input, p, training=self.training)
+
+
+class CustomDropout2d(nn.Module):
+ """A module around torch.nn.functional.dropout2d."""
+ def __init__(self):
+ super().__init__()
+ self._supports_custom_dropout = True
+
+ def forward(self, input: Tensor, p: float) -> Tensor:
+ return F.dropout2d(input, p, training=self.training)
+
+
+class SequentialWithDropout(nn.Sequential):
+ """Sequential of modules with dropout."""
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._supports_custom_dropout = True
+
+ def forward(self, x, p):
+ for module in self:
+ # if isinstance(module, (CustomDropout, SequentialWithDropout, DenseBlockWithDropout)):
+ if getattr(module, '_supports_custom_dropout', False): # TODO (nico): improve
+ x = module(x, p)
+ else:
+ x = module(x)
+ return x
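
A minimal usage sketch for these helpers (assuming the patched algoperf package is importable): SequentialWithDropout forwards the rate only to children that expose _supports_custom_dropout, so plain layers are called unchanged.

import torch
from torch import nn

from algoperf.workloads.dropout_modules import CustomDropout, SequentialWithDropout

block = SequentialWithDropout(
    nn.Linear(16, 16),
    nn.ReLU(),
    CustomDropout(),  # receives the per-call rate
)

x = torch.randn(2, 16)
block.train()
y_train = block(x, 0.3)  # dropout applied with p=0.3
block.eval()
y_eval = block(x, 0.3)   # no-op, since F.dropout respects self.training
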
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
new file mode 100644
index 000000000..5862f6352
--- /dev/null
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
@@ -0,0 +1,167 @@
+"""U-Net Model.
+
+Adapted from fastMRI:
+https://github.com/facebookresearch/fastMRI/blob/main/fastmri/models/unet.py
+"""
+
+from functools import partial
+from typing import Optional
+
+import torch
+from torch import nn
+from torch import Tensor
+from torch.nn import functional as F
+
+from algoperf import init_utils
+from algoperf.workloads.dropout_modules import CustomDropout2d, SequentialWithDropout
+
+
+
+class UNet(nn.Module):
+ r"""U-Net model from
+ `"U-net: Convolutional networks
+ for biomedical image segmentation"
+ `_.
+ """
+
+ def __init__(self,
+ in_chans: int = 1,
+ out_chans: int = 1,
+ num_channels: int = 32,
+ num_pool_layers: int = 4,
+ use_tanh: bool = False,
+ use_layer_norm: bool = False) -> None:
+ super().__init__()
+
+ self.in_chans = in_chans
+ self.out_chans = out_chans
+ self.num_channels = num_channels
+ self.num_pool_layers = num_pool_layers
+ self.down_sample_layers = nn.ModuleList([
+ ConvBlock(in_chans,
+ num_channels,
+ use_tanh,
+ use_layer_norm)
+ ])
+ ch = num_channels
+ for _ in range(num_pool_layers - 1):
+ self.down_sample_layers.append(
+ ConvBlock(ch, ch * 2, use_tanh, use_layer_norm))
+ ch *= 2
+ self.conv = ConvBlock(ch, ch * 2, use_tanh, use_layer_norm)
+
+ self.up_conv = nn.ModuleList()
+ self.up_transpose_conv = nn.ModuleList()
+
+ for _ in range(num_pool_layers - 1):
+ self.up_transpose_conv.append(
+ TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
+ self.up_conv.append(
+ ConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
+ ch //= 2
+
+ self.up_transpose_conv.append(
+ TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
+ self.up_conv.append(
+ SequentialWithDropout(
+ ConvBlock(ch * 2, ch, use_tanh, use_layer_norm),
+ nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1),
+ ))
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+ init_utils.pytorch_default_init(m)
+
+ def forward(self, x: Tensor, dropout_rate: float) -> Tensor:
+ stack = []
+ output = x
+
+ # apply down-sampling layers
+ for layer in self.down_sample_layers:
+ output = layer(output, dropout_rate)
+ stack.append(output)
+ output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)
+
+ output = self.conv(output, dropout_rate)
+
+ # apply up-sampling layers
+ for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv):
+ downsample_layer = stack.pop()
+ output = transpose_conv(output)
+
+ # reflect pad on the right/bottom if needed to handle
+ # odd input dimensions
+ padding = [0, 0, 0, 0]
+ if output.shape[-1] != downsample_layer.shape[-1]:
+ padding[1] = 1 # padding right
+ if output.shape[-2] != downsample_layer.shape[-2]:
+ padding[3] = 1 # padding bottom
+ if torch.sum(torch.tensor(padding)) != 0:
+ output = F.pad(output, padding, "reflect")
+
+ output = torch.cat([output, downsample_layer], dim=1)
+ output = conv(output, dropout_rate)
+
+ return output
+
+
+class ConvBlock(nn.Module):
+ # A Convolutional Block that consists of two convolution layers, each
+ # followed by instance normalization, LeakyReLU activation, and dropout.
+
+ def __init__(self,
+ in_chans: int,
+ out_chans: int,
+ use_tanh: bool,
+ use_layer_norm: bool) -> None:
+ super().__init__()
+ self._supports_custom_dropout = True
+
+ if use_layer_norm:
+ norm_layer = partial(nn.GroupNorm, 1, eps=1e-6)
+ else:
+ norm_layer = nn.InstanceNorm2d
+ if use_tanh:
+ activation_fn = nn.Tanh()
+ else:
+ activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ self.conv_layers = SequentialWithDropout(
+ nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
+ norm_layer(out_chans),
+ activation_fn,
+ CustomDropout2d(),
+ nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
+ norm_layer(out_chans),
+ activation_fn,
+ CustomDropout2d(),
+ )
+
+ def forward(self, x: Tensor, dropout_rate: float) -> Tensor:
+ return self.conv_layers(x, dropout_rate)
+
+
+class TransposeConvBlock(nn.Module):
+ # A Transpose Convolutional Block that consists of one transpose convolution
+ # layer followed by instance normalization and LeakyReLU activation.
+
+ def __init__(
+ self,
+ in_chans: int,
+ out_chans: int,
+ use_tanh: bool,
+ use_layer_norm: bool,
+ ):
+ super().__init__()
+ if use_tanh:
+ activation_fn = nn.Tanh()
+ else:
+ activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ self.layers = nn.Sequential(
+ nn.ConvTranspose2d(
+ in_chans, out_chans, kernel_size=2, stride=2, bias=False),
+ nn.InstanceNorm2d(out_chans),
+ activation_fn,
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.layers(x)
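
A small smoke-test sketch for the modified U-Net (the sizes below are illustrative; spatial dimensions must be divisible by 2**num_pool_layers, and the patched algoperf package is assumed to be importable):

import torch

from algoperf.workloads.fastmri.fastmri_pytorch.models_dropout import UNet

model = UNet(in_chans=1, out_chans=1, num_channels=16, num_pool_layers=2)
model.train()
x = torch.randn(2, 1, 64, 64)
out = model(x, dropout_rate=0.1)  # dropout rate supplied at call time
assert out.shape == (2, 1, 64, 64)
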
diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py
new file mode 100644
index 000000000..f5e315fd7
--- /dev/null
+++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py
@@ -0,0 +1,395 @@
+"""PyTorch implementation of refactored and simplified ViT.
+
+Adapted from:
+https://github.com/huggingface/transformers/tree/main/src/transformers/models/vit
+and https://github.com/lucidrains/vit-pytorch.
+"""
+
+import math
+from typing import Any, Optional, Tuple, Union
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from algoperf import init_utils
+from algoperf import spec
+from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention
+
+
+def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor:
+ """Follows the MoCo v3 logic."""
+ _, width, h, w = patches.shape
+ device = patches.device
+ y, x = torch.meshgrid(torch.arange(h, device=device),
+ torch.arange(w, device=device), indexing='ij')
+
+ if width % 4 != 0:
+ raise ValueError('Width must be mult of 4 for sincos posemb.')
+ omega = torch.arange(width // 4, device=device) / (width // 4 - 1)
+ omega = 1. / (temperature**omega)
+ y = y.flatten()[:, None] * omega[None, :]
+ x = x.flatten()[:, None] * omega[None, :]
+ pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
+ return pe[None, :, :]
+
+
+class MlpBlock(nn.Module):
+ """Transformer MLP / feed-forward block."""
+
+ def __init__(
+ self,
+ width: int,
+ mlp_dim: Optional[int] = None, # Defaults to 4x input dim.
+ use_glu: bool = False,
+ dropout_rate: float = 0.0) -> None:
+ super().__init__()
+
+ self.width = width
+ self.mlp_dim = mlp_dim or 4 * width
+ self.use_glu = use_glu
+ self.dropout_rate = dropout_rate
+
+ self.linear1 = nn.Linear(self.width, self.mlp_dim)
+ self.act_fnc = nn.GELU(approximate='tanh')
+
+ if self.use_glu:
+ self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim)
+ else:
+ self.glu_linear = None
+
+ self.linear2 = nn.Linear(self.mlp_dim, self.width)
+
+ self.reset_parameters()
+
+ def reset_parameters(self) -> None:
+ for module in self.modules():
+ if isinstance(module, nn.Linear):
+ nn.init.xavier_uniform_(module.weight.data)
+ if module.bias is not None:
+ module.bias.data.normal_(std=1e-6)
+
+ def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor:
+ if dropout_rate is None:
+ dropout_rate = self.dropout_rate
+
+ x = self.linear1(x)
+ x = self.act_fnc(x)
+
+ if self.use_glu:
+ y = self.glu_linear(x)
+ x = x * y
+
+ x = F.dropout(x, dropout_rate, training=self.training)
+ x = self.linear2(x)
+ return x
+
+
+class SelfAttention(nn.Module):
+ """Self-attention special case of multi-head dot-product attention."""
+
+ def __init__(self,
+ width: int,
+ num_heads: int = 8,
+ dropout_rate: float = 0.0) -> None:
+ super().__init__()
+
+ self.width = width
+ self.num_heads = num_heads
+
+ assert width % num_heads == 0, (
+ 'Memory dimension must be divisible by number of heads.')
+
+ self.head_dim = int(width / num_heads)
+ self.all_head_dim = self.num_heads * self.head_dim
+ self.dropout_rate = dropout_rate
+
+ self.query = nn.Linear(self.width, self.all_head_dim)
+ self.key = nn.Linear(self.width, self.all_head_dim)
+ self.value = nn.Linear(self.width, self.all_head_dim)
+ self.out = nn.Linear(self.width, self.width)
+ self.reset_parameters()
+
+ def reset_parameters(self) -> None:
+ for module in self.modules():
+ if isinstance(module, nn.Linear):
+ nn.init.xavier_uniform_(module.weight.data)
+ if module.bias is not None:
+ nn.init.constant_(module.bias.data, 0.)
+
+ def transpose_for_scores(self, x: spec.Tensor) -> spec.Tensor:
+ new_x_shape = x.size()[:-1] + (self.num_heads, self.head_dim)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor:
+ if dropout_rate is None:
+ dropout_rate = self.dropout_rate
+
+ mixed_query_layer = self.query(x)
+
+ key_layer = self.transpose_for_scores(self.key(x))
+ value_layer = self.transpose_for_scores(self.value(x))
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+ attention_scores = attention_scores / math.sqrt(self.head_dim)
+
+ attention_probs = F.softmax(attention_scores, dim=-1)
+ attention_probs = F.dropout(attention_probs, dropout_rate, training=self.training)
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_dim,)
+ context_layer = context_layer.view(new_context_layer_shape)
+ out = self.out(context_layer)
+ return out
+
+
+class Encoder1DBlock(nn.Module):
+ """Single transformer encoder block (MHSA + MLP)."""
+
+ def __init__(self,
+ width: int,
+ mlp_dim: Optional[int] = None,
+ num_heads: int = 12,
+ use_glu: bool = False,
+ use_post_layer_norm: bool = False,
+ dropout_rate: float = 0.0) -> None:
+ super().__init__()
+
+ self.width = width
+ self.mlp_dim = mlp_dim
+ self.num_heads = num_heads
+ self.use_glu = use_glu
+ self.use_post_layer_norm = use_post_layer_norm
+ self.dropout_rate = dropout_rate
+
+ self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6)
+ self.self_attention1 = SelfAttention(self.width, self.num_heads)
+ self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6)
+ self.mlp3 = MlpBlock(
+ width=self.width,
+ mlp_dim=self.mlp_dim,
+ use_glu=self.use_glu,
+ dropout_rate=dropout_rate)
+
+ def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor:
+ if dropout_rate is None:
+ dropout_rate = self.dropout_rate
+
+ if not self.use_post_layer_norm:
+ y = self.layer_norm0(x)
+ y = self.self_attention1(y)
+ y = F.dropout(y, dropout_rate, training=self.training)
+ x = x + y
+
+ y = self.layer_norm2(x)
+ y = self.mlp3(y)
+ y = F.dropout(y, dropout_rate, training=self.training)
+ x = x + y
+ else:
+ y = x
+ y = self.self_attention1(y)
+ y = F.dropout(y, dropout_rate, training=self.training)
+ x = x + y
+ x = self.layer_norm0(x)
+
+ y = x
+ y = self.mlp3(y)
+ y = F.dropout(y, dropout_rate, training=self.training)
+ x = x + y
+ x = self.layer_norm2(x)
+ return x
+
+
+class Encoder(nn.Module):
+ """Transformer Model Encoder for sequence to sequence translation."""
+
+ def __init__(self,
+ depth: int,
+ width: int,
+ mlp_dim: Optional[int] = None,
+ num_heads: int = 12,
+ use_glu: bool = False,
+ use_post_layer_norm: bool = False,
+ dropout_rate: float = 0.0) -> None:
+ super().__init__()
+
+ self.depth = depth
+ self.width = width
+ self.mlp_dim = mlp_dim
+ self.num_heads = num_heads
+ self.use_glu = use_glu
+ self.use_post_layer_norm = use_post_layer_norm
+
+ self.net = nn.ModuleList([
+ Encoder1DBlock(self.width,
+ self.mlp_dim,
+ self.num_heads,
+ self.use_glu,
+ self.use_post_layer_norm,
+ dropout_rate) for _ in range(depth)
+ ])
+
+ if not self.use_post_layer_norm:
+ self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6)
+ else:
+ self.encoder_norm = None
+
+ def forward(self, x: spec.Tensor) -> spec.Tensor:
+ # Input Encoder.
+ for block in self.net:
+ x = block(x)
+ if not self.use_post_layer_norm:
+ return self.encoder_norm(x)
+ else:
+ return x
+
+
+class MAPHead(nn.Module):
+ """Multihead Attention Pooling."""
+
+ def __init__(self,
+ width: int,
+ mlp_dim: Optional[int] = None,
+ num_heads: int = 12):
+ super().__init__()
+ self.width = width
+ self.mlp_dim = mlp_dim
+ self.num_heads = num_heads
+
+ self.probe = nn.Parameter(torch.zeros((1, 1, self.width)))
+ nn.init.xavier_uniform_(self.probe.data)
+
+ self.mha = MultiheadAttention(
+ self.width, num_heads=self.num_heads, self_attn=False, bias=True)
+ self.layer_norm = nn.LayerNorm(self.width, eps=1e-6)
+ self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim)
+
+ def forward(self, x: spec.Tensor) -> spec.Tensor:
+ n, _, _ = x.shape
+ probe = torch.tile(self.probe, [n, 1, 1])
+
+ x = self.mha(probe, x)[0]
+ y = self.layer_norm(x)
+ x = x + self.mlp(y)
+ return x[:, 0]
+
+
+class ViT(nn.Module):
+ """ViT model."""
+
+ image_height: int = 224
+ image_width: int = 224
+ channels: int = 3
+
+ def __init__(
+ self,
+ num_classes: int = 1000,
+ patch_size: Tuple[int, int] = (16, 16),
+ width: int = 768,
+ depth: int = 12,
+ mlp_dim: Optional[int] = None, # Defaults to 4x input dim.
+ num_heads: int = 12,
+ rep_size: Union[int, bool] = True,
+ dropout_rate: Optional[float] = 0.0,
+ head_zeroinit: bool = True,
+ use_glu: bool = False,
+ use_post_layer_norm: bool = False,
+ use_map: bool = False,
+ dtype: Any = torch.float32) -> None:
+ super().__init__()
+ if dropout_rate is None:
+ dropout_rate = 0.0
+
+ self.num_classes = num_classes
+ self.patch_size = patch_size
+ self.width = width
+ self.depth = depth
+ self.mlp_dim = mlp_dim
+ self.num_heads = num_heads
+ self.rep_size = rep_size
+ self.head_zeroinit = head_zeroinit
+ self.use_glu = use_glu
+ self.use_post_layer_norm = use_post_layer_norm
+ self.use_map = use_map
+ self.dtype = dtype
+ self.dropout_rate = dropout_rate
+
+ if self.rep_size:
+ rep_size = self.width if self.rep_size is True else self.rep_size
+ self.pre_logits = nn.Linear(self.width, rep_size)
+
+ self.conv_patch_extract = nn.Conv2d(
+ self.channels,
+ self.width,
+ self.patch_size,
+ stride=self.patch_size,
+ padding='valid')
+
+ self.encoder = Encoder(
+ depth=self.depth,
+ width=self.width,
+ mlp_dim=self.mlp_dim,
+ num_heads=self.num_heads,
+ use_glu=self.use_glu,
+ use_post_layer_norm=self.use_post_layer_norm,
+ dropout_rate=dropout_rate)
+
+ if self.num_classes:
+ self.head = nn.Linear(self.width, self.num_classes)
+
+ if self.use_map:
+ self.map = MAPHead(self.width, self.mlp_dim, self.num_heads)
+ else:
+ self.map = None
+
+ self.reset_parameters()
+
+ def reset_parameters(self) -> None:
+ init_utils.pytorch_default_init(self.conv_patch_extract)
+
+ if self.rep_size:
+ init_utils.pytorch_default_init(self.pre_logits)
+
+ if self.num_classes:
+ if self.head_zeroinit:
+ nn.init.constant_(self.head.weight.data, 0.)
+ nn.init.constant_(self.head.bias.data, 0.)
+ else:
+ init_utils.pytorch_default_init(self.head)
+
+ def get_posemb(self, x: spec.Tensor) -> spec.Tensor:
+ return posemb_sincos_2d(x).type(self.dtype)
+
+ def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor:
+ if dropout_rate is None:
+ dropout_rate = self.dropout_rate
+
+ # Patch extraction.
+ x = self.conv_patch_extract(x)
+
+ # Add posemb before adding extra token.
+ n, c, h, w = x.shape
+ pes = self.get_posemb(x)
+
+ # Reshape to match Jax's ViT implementation.
+ x = torch.transpose(torch.reshape(x, (n, c, h * w)), 1, 2)
+ x = x + pes
+
+ x = F.dropout(x, dropout_rate, training=self.training)
+ x = self.encoder(x)
+
+ if self.use_map:
+ x = self.map(x)
+ else:
+ x = torch.mean(x, dim=1)
+
+ if self.rep_size:
+ x = torch.tanh(self.pre_logits(x))
+
+ if self.num_classes:
+ x = self.head(x)
+
+ return x
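
A reduced-size sketch of the refactored ViT (the hyperparameters below are hypothetical and chosen only to keep the example light; the patched package is assumed to be importable):

import torch

from algoperf.workloads.imagenet_vit.imagenet_pytorch.models_dropout import ViT

model = ViT(num_classes=10, width=48, depth=2, num_heads=4, patch_size=(16, 16))
model.train()
x = torch.randn(2, 3, 224, 224)
logits = model(x, dropout_rate=0.1)  # overrides the rate stored at init
logits_default = model(x)            # falls back to self.dropout_rate
assert logits.shape == (2, 10)
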
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py
new file mode 100644
index 000000000..da66dfe43
--- /dev/null
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py
@@ -0,0 +1,518 @@
+"""This is a pytorch implementation mirroring:
+https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py.
+"""
+
+from dataclasses import dataclass
+from functools import partial
+import math
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+from torch.nn import init
+import torch.nn.functional as F
+
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \
+ preprocessor
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \
+ SpecAug
+
+
+@dataclass
+class ConformerConfig:
+ """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
+ vocab_size: int = 1024
+ encoder_dim: int = 512
+ num_attention_heads: int = 8
+ num_encoder_layers: int = 4
+ attention_dropout_rate: float = 0.0
+ # If None, defaults to 0.1.
+ attention_residual_dropout_rate: Optional[float] = 0.1
+ # If None, defaults to 0.0.
+ conv_residual_dropout_rate: Optional[float] = 0.0
+ feed_forward_dropout_rate: float = 0.0
+ # If None, defaults to 0.1.
+ feed_forward_residual_dropout_rate: Optional[float] = 0.1
+ convolution_kernel_size: int = 5
+ feed_forward_expansion_factor: int = 4
+ freq_mask_count: int = 2
+ freq_mask_max_bins: int = 27
+ time_mask_count: int = 10
+ time_mask_max_frames: int = 40
+ time_mask_max_ratio: float = 0.05
+ time_masks_per_frame: float = 0.0
+ use_dynamic_time_mask_max_frames: bool = True
+ # If None, defaults to 0.1.
+ input_dropout_rate: Optional[float] = 0.1
+ batch_norm_momentum: float = 1 - 0.999
+ batch_norm_epsilon: float = 0.001
+ use_specaug: bool = True
+ attention_temperature: float = 1.0
+ activation_function_name: str = 'swish'
+ use_post_layer_norm: bool = True
+
+
+def initialize(m):
+ if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d):
+ init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ init.constant_(m.bias, 0)
+ elif isinstance(m, nn.MultiheadAttention):
+ init.xavier_uniform_(m.in_proj_weight)
+ for i in m.children():
+ initialize(i)
+
+
+class LayerNorm(nn.Module):
+
+ def __init__(self, dim, epsilon=1e-6):
+ super().__init__()
+ self.dim = dim
+
+ self.scale = nn.Parameter(torch.zeros(self.dim))
+ self.bias = nn.Parameter(torch.zeros(self.dim))
+ self.epsilon = epsilon
+
+ def forward(self, x):
+ return F.layer_norm(x, (self.dim,), 1 + self.scale, self.bias, self.epsilon)
+
+
+class Subsample(nn.Module):
+
+ def __init__(self,
+ encoder_dim: int = 0,
+ input_dropout_rate: float = 0.0,
+ num_bins: int = 80):
+ super().__init__()
+ self.encoder_dim = encoder_dim
+ self.input_dropout_rate = input_dropout_rate
+
+ self.conv1 = Conv2dSubsampling(
+ input_channels=1, output_channels=encoder_dim)
+ self.conv2 = Conv2dSubsampling(
+ input_channels=encoder_dim, output_channels=encoder_dim)
+
+ self.linear = nn.Linear(
+ in_features=self.encoder_dim * num_bins // 4,
+ out_features=self.encoder_dim,
+ bias=True)
+ self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim)
+
+ def forward(self, inputs, input_paddings, dropout_rate=None):
+ if dropout_rate is None:
+ dropout_rate = self.input_dropout_rate
+
+ output_paddings = input_paddings
+ outputs = inputs[:, None, :, :]
+
+ outputs, output_paddings = self.conv1(outputs, output_paddings)
+ outputs, output_paddings = self.conv2(outputs, output_paddings)
+
+ batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape
+ outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size,
+ subsampled_lengths,
+ subsampled_dims * channels)
+
+ outputs = self.linear(outputs)
+ outputs = outputs + self.pos_encode(seq_length=outputs.shape[1])
+ outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True)
+
+ return outputs, output_paddings
+
+
+class Conv2dSubsampling(nn.Module):
+
+ def __init__(self,
+ input_channels: int,
+ output_channels: int,
+ filter_stride: Tuple[int] = (2, 2),
+ padding: str = 'SAME'):
+ super().__init__()
+
+ self.input_channels = input_channels
+ self.output_channels = output_channels
+ self.filter_stride = filter_stride
+ self.padding = padding
+
+ self.filter_shape = (output_channels, input_channels, 3, 3)
+
+ self.kernel = nn.Parameter(
+ torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape)))
+ self.bias = nn.Parameter(torch.zeros(output_channels))
+ self.register_buffer('paddings_kernel', torch.ones([1, 1, 1]))
+
+ def get_same_padding(self, input_shape):
+ in_height, in_width = input_shape[2:]
+ stride_height, stride_width = self.filter_stride
+ filter_height, filter_width = 3, 3
+ if in_height % stride_height == 0:
+ pad_along_height = max(filter_height - stride_height, 0)
+ else:
+ pad_along_height = max(filter_height - (in_height % stride_height), 0)
+ if in_width % stride_width == 0:
+ pad_along_width = max(filter_width - stride_width, 0)
+ else:
+ pad_along_width = max(filter_width - (in_width % stride_width), 0)
+ pad_top = pad_along_height // 2
+ pad_bottom = pad_along_height - pad_top
+ pad_left = pad_along_width // 2
+ pad_right = pad_along_width - pad_left
+ return (pad_left, pad_right, pad_top, pad_bottom)
+
+ def forward(self, inputs, paddings):
+ groups = inputs.shape[1] // self.input_channels
+
+ if self.padding == 'SAME':
+ in_ = F.pad(inputs, self.get_same_padding(inputs.shape))
+ else:
+ in_ = inputs
+ outputs = F.conv2d(
+ input=in_,
+ weight=self.kernel,
+ bias=self.bias,
+ stride=self.filter_stride,
+ dilation=(1, 1),
+ groups=groups)
+
+ outputs = F.relu(outputs)
+
+ input_length = paddings.shape[1]
+ stride = self.filter_stride[0]
+ pad_len = (input_length + stride - 1) // stride * stride - input_length
+ padded_paddings = F.pad(
+ paddings[:, None, :], (0, pad_len), mode='constant', value=0)
+ out_padding = F.conv1d(
+ input=padded_paddings,
+ weight=self.paddings_kernel,
+ stride=self.filter_stride[:1])
+ out_padding = out_padding.squeeze(dim=1)
+ outputs = outputs * (1 - out_padding[:, None, :, None])
+ return outputs, out_padding
+
+
+class FeedForwardModule(nn.Module):
+
+ def __init__(self, config: ConformerConfig):
+ super().__init__()
+ self.config = config
+
+ self.ln = LayerNorm(dim=config.encoder_dim)
+ self.linear1 = nn.Linear(
+ in_features=config.encoder_dim,
+ out_features=config.encoder_dim * config.feed_forward_expansion_factor,
+ bias=True)
+ self.dropout1 = nn.Dropout(p=config.feed_forward_dropout_rate, inplace=True)
+ self.linear2 = nn.Linear(
+ in_features=config.encoder_dim * config.feed_forward_expansion_factor,
+ out_features=config.encoder_dim,
+ bias=True)
+
+ if config.feed_forward_residual_dropout_rate is None:
+ self.feed_forward_residual_dropout_rate = 0.1
+ else:
+ self.feed_forward_residual_dropout_rate = config.feed_forward_residual_dropout_rate
+
+ def forward(self, inputs, padding_mask, dropout_rate=None):
+ if dropout_rate is None:
+ dropout_rate = self.feed_forward_residual_dropout_rate
+
+ inputs = self.ln(inputs)
+ inputs = self.linear1(inputs)
+ if self.config.activation_function_name == 'swish':
+ activation_fn = F.silu
+ elif self.config.activation_function_name == 'gelu':
+ # Use the tanh approximation of GELU, which is the default in Jax.
+ activation_fn = partial(F.gelu, approximate='tanh')
+ else:
+ raise ValueError('Only "swish" and "gelu" are supported '
+ 'config.activation_function_name values, received '
+ f'{self.config.activation_function_name}')
+ inputs = activation_fn(inputs)
+ inputs = self.dropout1(inputs)
+ inputs = inputs * padding_mask
+ inputs = self.linear2(inputs)
+ inputs = inputs * padding_mask
+ inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True)
+
+ return inputs
+
+
+class AddPositionalEmbedding(nn.Module):
+
+ def __init__(self,
+ min_timescale: int = 1,
+ max_timescale: int = 10_000,
+ embedding_dim: int = 512):
+ super().__init__()
+ self.min_timescale = min_timescale
+ self.max_timescale = max_timescale
+ self.embedding_dim = embedding_dim
+ num_timescales = self.embedding_dim // 2
+ log_timescale_increment = math.log(
+ float(self.max_timescale) / float(self.min_timescale)) / (
+ num_timescales - 1)
+ inv_timescales = self.min_timescale * \
+ torch.exp(torch.arange(num_timescales, dtype=torch.float32)
+ * -log_timescale_increment)
+ self.register_buffer('inv_timescales', inv_timescales[None, None, :])
+
+ def forward(self, seq_length):
+ position = torch.arange(
+ end=seq_length, dtype=torch.float32, device=self.inv_timescales.device)
+ scaled_time = position[None, :, None] * self.inv_timescales
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
+ if self.embedding_dim % 2:
+ signal = torch.cat(
+ [signal, torch.zeros(signal.shape[0], signal.shape[1], 1)], dim=2)
+ return signal
+
+
+class QueryScaler(nn.Module):
+
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+ self.scale = nn.Parameter(torch.zeros(self.dim))
+
+ def forward(self, inputs):
+ r_softplus_0 = 1.442695041
+ scale = r_softplus_0 * F.softplus(self.scale)
+ return inputs * scale
+
+
+class MHSAwithQS(nn.Module):
+
+ def __init__(self, config: ConformerConfig):
+ super().__init__()
+ self.embed_dim = config.encoder_dim
+ self.num_heads = config.num_attention_heads
+ self.attention_dropout_rate = config.attention_dropout_rate
+ self.in_proj = nn.Linear(config.encoder_dim, 3 * config.encoder_dim)
+ self.out_proj = nn.Linear(config.encoder_dim, config.encoder_dim)
+ self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads)
+ self.attention_temperature = config.attention_temperature
+
+ def forward(self, inputs, key_padding_mask=None):
+ batch_size, seq_len, embed_dim = inputs.shape
+ q, k, v = self.in_proj(inputs).split(self.embed_dim, dim=2)
+ q = self.qs(q.view(batch_size, seq_len, self.num_heads, -1)).transpose(1, 2)
+ k = k.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
+ v = v.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
+ out = F.scaled_dot_product_attention(
+ query=q,
+ key=k,
+ value=v,
+ attn_mask=~key_padding_mask[:, None, None],
+ dropout_p=self.attention_dropout_rate,
+ ).transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
+ out = out * self.attention_temperature
+ out = self.out_proj(out)
+ return out
+
+
+class MultiHeadedSelfAttention(nn.Module):
+
+ def __init__(self, config: ConformerConfig):
+ super().__init__()
+
+ self.config = config
+
+ self.ln = LayerNorm(dim=config.encoder_dim)
+ self.self_attention = MHSAwithQS(config)
+ if config.attention_residual_dropout_rate is None:
+ self.attention_residual_dropout_rate = 0.1
+ else:
+ self.attention_residual_dropout_rate = config.attention_residual_dropout_rate
+
+ def forward(self, outputs, paddings, dropout_rate=None):
+ if dropout_rate is None:
+ dropout_rate = self.attention_residual_dropout_rate
+
+ outputs = self.ln(outputs)
+ outputs = self.self_attention(
+ outputs,
+ key_padding_mask=paddings == 1,
+ )
+ outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True)
+ return outputs
+
+
+class BatchNorm(nn.Module):
+
+ def __init__(self, config: ConformerConfig):
+ super().__init__()
+ running_mean = torch.zeros(config.encoder_dim)
+ running_var = torch.ones(config.encoder_dim)
+ self.register_buffer('running_mean', running_mean)
+ self.register_buffer('running_var', running_var)
+ self.scale = nn.Parameter(torch.zeros(config.encoder_dim))
+ self.bias = nn.Parameter(torch.zeros(config.encoder_dim))
+
+ self.register_buffer('dim', torch.FloatTensor([config.encoder_dim]))
+ self.momentum = config.batch_norm_momentum
+ self.epsilon = config.batch_norm_epsilon
+
+ def forward(self, inputs, input_paddings):
+ #inputs: NHD
+ #padding: NH
+ """
+ Alternatively:
+ inputs[input_paddings==0] = F.batch_norm(
+ input = inputs[input_paddings==0],
+ running_mean = self.running_mean,
+ running_var = self.running_var,
+ weight = 1+self.scale,
+ bias = self.bias,
+ training = self.training,
+ momentum=1-self.momentum,
+ eps=self.epsilon
+ )
+ inputs.masked_fill(input_paddings[...,None] != 0, 0)
+ return inputs
+ """
+ mask = 1 - input_paddings[:, :, None]
+ if self.training:
+ count = mask.sum()
+ masked_inp = inputs.masked_fill(mask == 0, 0)
+ mean = (masked_inp).sum(dim=(0, 1)) / count
+ var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count
+
+ self.running_mean = (1 - self.momentum) * self.running_mean + (
+ self.momentum) * mean.detach()
+ self.running_var = (1 - self.momentum) * self.running_var + (
+ self.momentum) * var.detach()
+
+ else:
+ mean = self.running_mean
+ var = self.running_var
+ v = (1 + self.scale) * torch.rsqrt(var + self.epsilon)
+ bn = (inputs - mean) * v + self.bias
+ output = bn.masked_fill(mask == 0, 0)
+ return output
+
+
+class ConvolutionBlock(nn.Module):
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.config = config
+ self.ln = LayerNorm(dim=config.encoder_dim)
+ self.lin1 = nn.Linear(
+ in_features=config.encoder_dim, out_features=config.encoder_dim)
+ self.lin2 = nn.Linear(
+ in_features=config.encoder_dim, out_features=config.encoder_dim)
+
+ self.conv1 = nn.Conv1d(
+ in_channels=config.encoder_dim,
+ out_channels=config.encoder_dim,
+ kernel_size=(config.convolution_kernel_size,),
+ stride=(1,),
+ padding='same',
+ bias=False,
+ groups=config.encoder_dim)
+ self.bn = BatchNorm(config)
+ self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim)
+ if config.conv_residual_dropout_rate is None:
+ self.conv_residual_dropout_rate = 0.0
+ else:
+ self.conv_residual_dropout_rate = config.conv_residual_dropout_rate
+
+ def forward(self, inputs, input_paddings, dropout_rate=None):
+ if dropout_rate is None:
+ dropout_rate = self.conv_residual_dropout_rate
+
+ inputs = self.ln(inputs)
+
+ inputs = F.glu(torch.cat([self.lin1(inputs), self.lin2(inputs)], dim=2))
+ inputs = inputs * (1 - input_paddings[:, :, None])
+
+ inputs = inputs.permute(0, 2, 1)
+ inputs = self.conv1(inputs)
+ inputs = inputs.permute(0, 2, 1)
+
+ inputs = self.bn(inputs, input_paddings)
+ if self.config.activation_function_name == 'swish':
+ activation_fn = F.silu
+ elif self.config.activation_function_name == 'gelu':
+ activation_fn = F.gelu
+ else:
+ raise ValueError('Only "swish" and "gelu" are supported '
+ 'config.activation_function_name values, received '
+ f'{self.config.activation_function_name}')
+ inputs = activation_fn(inputs)
+ inputs = self.lin3(inputs)
+
+ inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True)
+ return inputs
+
+
+class ConformerBlock(nn.Module):
+
+ def __init__(self, config: ConformerConfig):
+ super().__init__()
+
+ self.ff1 = FeedForwardModule(config)
+ self.mhsa = MultiHeadedSelfAttention(config)
+ self.conv = ConvolutionBlock(config)
+ self.ff2 = FeedForwardModule(config)
+ self.ln = None
+ if config.use_post_layer_norm:
+ self.ln = LayerNorm(dim=config.encoder_dim)
+
+ def forward(self, inputs, input_paddings, dropout_rate=None):
+ padding_mask = 1 - input_paddings[:, :, None]
+ inputs = inputs + 0.5 * self.ff1(inputs, padding_mask, dropout_rate)
+ inputs = inputs + self.mhsa(inputs, input_paddings, dropout_rate)
+ inputs = inputs + self.conv(inputs, input_paddings, dropout_rate)
+ inputs = inputs + 0.5 * self.ff2(inputs, padding_mask, dropout_rate)
+ if self.ln:
+ inputs = self.ln(inputs)
+ return inputs
+
+
+class ConformerEncoderDecoder(nn.Module):
+
+ def __init__(self, config: ConformerConfig):
+ super().__init__()
+ self.config = config
+ preprocessing_config = preprocessor.PreprocessorConfig()
+ self.preprocessor = preprocessor.MelFilterbankFrontend(
+ preprocessing_config,
+ per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
+ per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)
+ self.specaug = SpecAug(
+ freq_mask_count=config.freq_mask_count,
+ freq_mask_max_bins=config.freq_mask_max_bins,
+ time_mask_count=config.time_mask_count,
+ time_mask_max_frames=config.time_mask_max_frames,
+ time_mask_max_ratio=config.time_mask_max_ratio,
+ time_masks_per_frame=config.time_masks_per_frame,
+ use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames
+ )
+ if config.input_dropout_rate is None:
+ input_dropout_rate = 0.1
+ else:
+ input_dropout_rate = config.input_dropout_rate
+ self.subsample = Subsample(
+ encoder_dim=config.encoder_dim,
+ input_dropout_rate=input_dropout_rate,
+ num_bins=preprocessing_config.num_bins)
+ self.conformers = nn.ModuleList(
+ [ConformerBlock(config) for _ in range(config.num_encoder_layers)])
+
+ self.ln = LayerNorm(config.encoder_dim)
+ self.lin = nn.Linear(config.encoder_dim, config.vocab_size)
+
+ def forward(self, inputs, input_paddings, dropout_rate=None):
+ outputs = inputs
+ output_paddings = input_paddings
+ outputs, output_paddings = self.preprocessor(outputs, output_paddings)
+ if self.training and self.config.use_specaug:
+ outputs, output_paddings = self.specaug(outputs, output_paddings)
+ outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate)
+ for conformer in self.conformers:
+ outputs = conformer(outputs, output_paddings, dropout_rate)
+ outputs = self.ln(outputs)
+ outputs = self.lin(outputs)
+ return outputs, output_paddings
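
To see the per-call override in isolation, a sketch exercising FeedForwardModule directly (shapes are illustrative; the patched package is assumed to be importable):

import torch

from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import (
    ConformerConfig, FeedForwardModule)

config = ConformerConfig(encoder_dim=64)
ffn = FeedForwardModule(config).train()

inputs = torch.randn(2, 10, 64)      # (batch, time, encoder_dim)
padding_mask = torch.ones(2, 10, 1)  # 1 = valid frame, 0 = padded frame
out = ffn(inputs, padding_mask, dropout_rate=0.2)  # per-call residual dropout
out_default = ffn(inputs, padding_mask)            # config default (0.1)
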
From 3e7a3967ba5f4845826e971ba5f0c49fa4c031b9 Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Wed, 11 Jun 2025 11:22:23 +0200
Subject: [PATCH 033/123] dropout fix deepspeech, ogbg
---
algoperf/workloads/dropout_modules.py | 5 +-
.../librispeech_pytorch/models_dropout.py | 395 ++++++++++++++++++
.../ogbg/ogbg_pytorch/models_dropout.py | 314 ++++++++++++++
3 files changed, 711 insertions(+), 3 deletions(-)
create mode 100644 algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py
create mode 100644 algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py
diff --git a/algoperf/workloads/dropout_modules.py b/algoperf/workloads/dropout_modules.py
index 3917b75bf..6cec3f7ad 100644
--- a/algoperf/workloads/dropout_modules.py
+++ b/algoperf/workloads/dropout_modules.py
@@ -31,10 +31,9 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._supports_custom_dropout = True
- def forward(self, x, p):
+ def forward(self, x: Tensor, p: float) -> Tensor:
for module in self:
- # if isinstance(module, (CustomDropout, SequentialWithDropout, DenseBlockWithDropout)):
- if getattr(module, '_supports_custom_dropout', False): # TODO (nico): improve
+ if getattr(module, '_supports_custom_dropout', False):
x = module(x, p)
else:
x = module(x)
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py
new file mode 100644
index 000000000..e68a820ed
--- /dev/null
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py
@@ -0,0 +1,395 @@
+"""This is a pytorch implementation mirroring:
+https://github.com/google/init2winit/blob/master/init2winit/model_lib/deepspeech.py.
+"""
+
+from dataclasses import dataclass
+import os
+from typing import Optional, Tuple
+
+import torch
+from torch import nn
+import torch.distributed.nn as dist_nn
+import torch.nn.functional as F
+
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \
+ preprocessor
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \
+ SpecAug
+
+USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
+
+
+@dataclass
+class DeepspeechConfig:
+ """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
+ vocab_size: int = 1024
+ encoder_dim: int = 512
+ num_lstm_layers: int = 6
+ num_ffn_layers: int = 3
+ conv_subsampling_factor: int = 2
+ conv_subsampling_layers: int = 2
+ use_specaug: bool = True
+ freq_mask_count: int = 2
+ freq_mask_max_bins: int = 27
+ time_mask_count: int = 10
+ time_mask_max_frames: int = 40
+ time_mask_max_ratio: float = 0.05
+ time_masks_per_frame: float = 0.0
+ use_dynamic_time_mask_max_frames: bool = True
+ batch_norm_momentum: float = 1 - 0.999
+ batch_norm_epsilon: float = 0.001
+ # If None, defaults to 0.1.
+ input_dropout_rate: Optional[float] = 0.1
+ # If None, defaults to 0.1.
+ feed_forward_dropout_rate: Optional[float] = 0.1
+ enable_residual_connections: bool = True
+ enable_decoder_layer_norm: bool = True
+ bidirectional: bool = True
+ use_tanh: bool = False
+ layernorm_everywhere: bool = False
+
+
+class LayerNorm(nn.Module):
+
+ def __init__(self, dim, epsilon=1e-6):
+ super().__init__()
+ self.dim = dim
+
+ self.scale = nn.Parameter(torch.zeros(self.dim))
+ self.bias = nn.Parameter(torch.zeros(self.dim))
+ self.epsilon = epsilon
+
+ def forward(self, x):
+ mean = x.mean(dim=-1, keepdims=True)
+ var = x.var(dim=-1, unbiased=False, keepdims=True)
+
+ normed_x = (x - mean) * torch.rsqrt(var + self.epsilon)
+ normed_x *= (1 + self.scale)
+ normed_x += self.bias
+
+ return normed_x
+
+
+class Subsample(nn.Module):
+
+ def __init__(self, config: DeepspeechConfig):
+ super().__init__()
+ encoder_dim = config.encoder_dim
+
+ self.encoder_dim = encoder_dim
+
+ self.conv1 = Conv2dSubsampling(
+ input_channels=1, output_channels=encoder_dim, use_tanh=config.use_tanh)
+ self.conv2 = Conv2dSubsampling(
+ input_channels=encoder_dim,
+ output_channels=encoder_dim,
+ use_tanh=config.use_tanh)
+
+ self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True)
+
+ if config.input_dropout_rate is None:
+ self.input_dropout_rate = 0.1
+ else:
+ self.input_dropout_rate = config.input_dropout_rate
+
+ def forward(self, inputs, input_paddings, dropout_rate=None):
+ if dropout_rate is None:
+ dropout_rate = self.input_dropout_rate
+
+ output_paddings = input_paddings
+ outputs = inputs[:, None, :, :]
+
+ outputs, output_paddings = self.conv1(outputs, output_paddings)
+ outputs, output_paddings = self.conv2(outputs, output_paddings)
+
+ batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape
+ outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size,
+ subsampled_lengths,
+ subsampled_dims * channels)
+
+ outputs = self.lin(outputs)
+ outputs = F.dropout(outputs, dropout_rate, training=self.training)
+
+ return outputs, output_paddings
+
+
+class Conv2dSubsampling(nn.Module):
+
+ def __init__(self,
+ input_channels: int,
+ output_channels: int,
+ filter_stride: Tuple[int] = (2, 2),
+ padding: str = 'SAME',
+ batch_norm_momentum: float = 0.999,
+ batch_norm_epsilon: float = 0.001,
+ use_tanh: bool = False):
+ super().__init__()
+
+ self.input_channels = input_channels
+ self.output_channels = output_channels
+ self.filter_stride = filter_stride
+ self.padding = padding
+
+ self.filter_shape = (output_channels, input_channels, 3, 3)
+
+ self.kernel = nn.Parameter(
+ nn.init.xavier_uniform_(torch.empty(*self.filter_shape)))
+ self.bias = nn.Parameter(torch.zeros(output_channels))
+
+ self.use_tanh = use_tanh
+
+ def get_same_padding(self, input_shape):
+ in_height, in_width = input_shape[2:]
+ stride_height, stride_width = self.filter_stride
+ filter_height, filter_width = 3, 3
+ if in_height % stride_height == 0:
+ pad_along_height = max(filter_height - stride_height, 0)
+ else:
+ pad_along_height = max(filter_height - (in_height % stride_height), 0)
+ if in_width % stride_width == 0:
+ pad_along_width = max(filter_width - stride_width, 0)
+ else:
+ pad_along_width = max(filter_width - (in_width % stride_width), 0)
+ pad_top = pad_along_height // 2
+ pad_bottom = pad_along_height - pad_top
+ pad_left = pad_along_width // 2
+ pad_right = pad_along_width - pad_left
+ return (pad_left, pad_right, pad_top, pad_bottom)
+
+ def forward(self, inputs, paddings):
+ groups = inputs.shape[1] // self.input_channels
+
+ if self.padding == 'SAME':
+ in_ = F.pad(inputs, self.get_same_padding(inputs.shape))
+ else:
+ in_ = inputs
+ outputs = F.conv2d(
+ input=in_,
+ weight=self.kernel,
+ bias=self.bias,
+ stride=self.filter_stride,
+ dilation=(1, 1),
+ groups=groups)
+
+ if self.use_tanh:
+ outputs = F.tanh(outputs)
+ else:
+ outputs = F.relu(outputs)
+
+ input_length = paddings.shape[1]
+ stride = self.filter_stride[0]
+ pad_len = (input_length + stride - 1) // stride * stride - input_length
+ out_padding = F.conv1d(
+ input=torch.cat([
+ paddings[:, None, :],
+ torch.zeros(
+ size=(paddings.shape[0], 1, pad_len), device=paddings.device)
+ ],
+ dim=2),
+ weight=torch.ones([1, 1, 1], device=paddings.device),
+ stride=self.filter_stride[:1])
+ out_padding = out_padding.squeeze(dim=1)
+ outputs = outputs * (1 - out_padding[:, None, :, None])
+ return outputs, out_padding
+
+
+class FeedForwardModule(nn.Module):
+
+ def __init__(self, config: DeepspeechConfig):
+ super().__init__()
+ self.config = config
+
+ if config.layernorm_everywhere:
+ self.normalization_layer = LayerNorm(config.encoder_dim)
+ else:
+ self.bn_normalization_layer = BatchNorm(
+ dim=config.encoder_dim,
+ batch_norm_momentum=config.batch_norm_momentum,
+ batch_norm_epsilon=config.batch_norm_epsilon)
+ self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True)
+ if config.feed_forward_dropout_rate is None:
+ self.feed_forward_dropout_rate = 0.1
+ else:
+ self.feed_forward_dropout_rate = config.feed_forward_dropout_rate
+
+ def forward(self, inputs, input_paddings, dropout_rate=None):
+ if dropout_rate is None:
+ dropout_rate = self.feed_forward_dropout_rate
+
+ padding_mask = (1 - input_paddings)[:, :, None]
+ if self.config.layernorm_everywhere:
+ inputs = self.normalization_layer(inputs)
+ else: # batchnorm
+ inputs = self.bn_normalization_layer(inputs, input_paddings)
+
+ inputs = self.lin(inputs)
+
+ if self.config.use_tanh:
+ inputs = F.tanh(inputs)
+ else:
+ inputs = F.relu(inputs)
+
+ inputs = inputs * padding_mask
+ inputs = F.dropout(inputs, dropout_rate, training=self.training)
+
+ return inputs
+
+
+class BatchNorm(nn.Module):
+
+ def __init__(self, dim, batch_norm_momentum, batch_norm_epsilon):
+ super().__init__()
+ running_mean = torch.zeros(dim)
+ running_var = torch.ones(dim)
+ self.register_buffer('running_mean', running_mean)
+ self.register_buffer('running_var', running_var)
+ self.weight = nn.Parameter(torch.zeros(dim))
+ self.bias = nn.Parameter(torch.zeros(dim))
+
+ self.momentum = batch_norm_momentum
+ self.epsilon = batch_norm_epsilon
+ self.dim = dim
+
+ def forward(self, inputs, input_paddings):
+ #inputs: NHD
+ #padding: NH
+ mask = 1 - input_paddings[:, :, None]
+ if self.training:
+ count = mask.sum()
+ masked_inp = inputs.masked_fill(mask == 0, 0)
+ sum_ = (masked_inp).sum(dim=(0, 1))
+ if USE_PYTORCH_DDP:
+ sum_ = dist_nn.all_reduce(sum_)
+ count = dist_nn.all_reduce(count)
+ mean = sum_ / count
+
+ sum_ = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1))
+ if USE_PYTORCH_DDP:
+ sum_ = dist_nn.all_reduce(sum_)
+ var = sum_ / count
+
+ self.running_mean = (1 - self.momentum) * self.running_mean + (
+ self.momentum) * mean.detach()
+ self.running_var = (1 - self.momentum) * self.running_var + (
+ self.momentum) * var.detach()
+ else:
+ mean = self.running_mean
+ var = self.running_var
+ v = (1 + self.weight) * torch.rsqrt(var + self.epsilon)
+ bn = (inputs - mean) * v + self.bias
+ output = bn.masked_fill(mask == 0, 0)
+ return output
+
+
+class BatchRNN(nn.Module):
+
+ def __init__(self, config: DeepspeechConfig):
+ super().__init__()
+ self.config = config
+ hidden_size = config.encoder_dim
+ input_size = config.encoder_dim
+ bidirectional = config.bidirectional
+ self.bidirectional = bidirectional
+
+ if config.layernorm_everywhere:
+ self.normalization_layer = LayerNorm(config.encoder_dim)
+ else:
+ self.bn_normalization_layer = BatchNorm(config.encoder_dim,
+ config.batch_norm_momentum,
+ config.batch_norm_epsilon)
+
+ if bidirectional:
+ self.lstm = nn.LSTM(
+ input_size=input_size,
+ hidden_size=hidden_size // 2,
+ bidirectional=True,
+ batch_first=True)
+ else:
+ self.lstm = nn.LSTM(
+ input_size=input_size, hidden_size=hidden_size, batch_first=True)
+
+ def forward(self, inputs, input_paddings):
+ if self.config.layernorm_everywhere:
+ inputs = self.normalization_layer(inputs)
+ else:
+ inputs = self.bn_normalization_layer(inputs, input_paddings)
+ lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy()
+ packed_inputs = torch.nn.utils.rnn.pack_padded_sequence(
+ inputs, lengths, batch_first=True, enforce_sorted=False)
+ packed_outputs, _ = self.lstm(packed_inputs)
+ outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(
+ packed_outputs, batch_first=True)
+ if outputs.shape[1] < inputs.shape[1]:
+ outputs = torch.cat([
+ outputs,
+ torch.zeros(
+ size=(outputs.shape[0],
+ inputs.shape[1] - outputs.shape[1],
+ outputs.shape[2]),
+ device=outputs.device)
+ ],
+ dim=1)
+ return outputs
+
+
+class DeepspeechEncoderDecoder(nn.Module):
+
+ def __init__(self, config: DeepspeechConfig):
+ super().__init__()
+ self.config = config
+
+ self.specaug = SpecAug(
+ freq_mask_count=config.freq_mask_count,
+ freq_mask_max_bins=config.freq_mask_max_bins,
+ time_mask_count=config.time_mask_count,
+ time_mask_max_frames=config.time_mask_max_frames,
+ time_mask_max_ratio=config.time_mask_max_ratio,
+ time_masks_per_frame=config.time_masks_per_frame,
+ use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames
+ )
+ preprocessing_config = preprocessor.PreprocessorConfig()
+ self.preprocessor = preprocessor.MelFilterbankFrontend(
+ preprocessing_config,
+ per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
+ per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)
+
+ self.subsample = Subsample(config=config)
+
+ self.lstms = nn.ModuleList(
+ [BatchRNN(config) for _ in range(config.num_lstm_layers)])
+ self.ffns = nn.ModuleList(
+ [FeedForwardModule(config) for _ in range(config.num_ffn_layers)])
+
+ if config.enable_decoder_layer_norm:
+ self.ln = LayerNorm(config.encoder_dim)
+ else:
+ self.ln = nn.Identity()
+
+ self.lin = nn.Linear(config.encoder_dim, config.vocab_size)
+
+ def forward(self, inputs, input_paddings, dropout_rate=None):
+ outputs = inputs
+ output_paddings = input_paddings
+
+ outputs, output_paddings = self.preprocessor(outputs, output_paddings)
+ if self.training and self.config.use_specaug:
+ outputs, output_paddings = self.specaug(outputs, output_paddings)
+ outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate)
+ for idx in range(self.config.num_lstm_layers):
+ if self.config.enable_residual_connections:
+ outputs = outputs + self.lstms[idx](outputs, output_paddings)
+ else:
+ outputs = self.lstms[idx](outputs, output_paddings)
+
+ for idx in range(self.config.num_ffn_layers):
+ if self.config.enable_residual_connections:
+ outputs = outputs + self.ffns[idx](outputs, output_paddings, dropout_rate)
+ else:
+ outputs = self.ffns[idx](outputs, output_paddings, dropout_rate)
+
+ if self.config.enable_decoder_layer_norm:
+ outputs = self.ln(outputs)
+
+ outputs = self.lin(outputs)
+
+ return outputs, output_paddings
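
Similarly, the DeepSpeech front end can be exercised on its own; a sketch with illustrative shapes (80 filterbank bins, matching the preprocessor output), assuming the patched package is importable:

import torch

from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import (
    DeepspeechConfig, Subsample)

config = DeepspeechConfig(encoder_dim=64)
subsample = Subsample(config).train()

features = torch.randn(2, 16, 80)  # (batch, frames, num_bins)
paddings = torch.zeros(2, 16)      # 0 = valid frame, 1 = padded frame
outputs, out_paddings = subsample(features, paddings, dropout_rate=0.2)
# The two stride-2 subsampling convolutions reduce 16 frames to 4.
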
diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py
new file mode 100644
index 000000000..1d89ea9e7
--- /dev/null
+++ b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py
@@ -0,0 +1,314 @@
+# Ported to PyTorch from
+# https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py.
+from functools import partial
+from typing import Callable, Optional, Tuple
+
+import jax.tree_util as tree
+from jraph import GraphsTuple
+import torch
+from torch import nn
+
+from algoperf import init_utils
+from algoperf.workloads.dropout_modules import CustomDropout, SequentialWithDropout
+
+
+def _make_mlp(in_dim, hidden_dims, activation_fn):
+ """Creates a MLP with specified dimensions."""
+ layers = SequentialWithDropout()
+ for i, dim in enumerate(hidden_dims):
+ layers.add_module(f'dense_{i}',
+ nn.Linear(in_features=in_dim, out_features=dim))
+ layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6))
+ layers.add_module(f'activation_fn_{i}', activation_fn())
+ layers.add_module(f'dropout_{i}', CustomDropout())
+ in_dim = dim
+ return layers
+
+
+class GNN(nn.Module):
+ """Defines a graph network.
+
+ The model assumes the input data is a jraph.GraphsTuple without global
+ variables. The final prediction will be encoded in the globals.
+ """
+
+ def __init__(self,
+ num_outputs: int = 128,
+ dropout_rate: Optional[float] = 0.1,
+ activation_fn_name: str = 'relu',
+ latent_dim: int = 256,
+ hidden_dims: Tuple[int] = (256,),
+ num_message_passing_steps: int = 5) -> None:
+ super().__init__()
+ self.latent_dim = latent_dim
+ self.hidden_dims = hidden_dims
+ self.num_message_passing_steps = num_message_passing_steps
+ self.num_outputs = num_outputs
+ # Fall back to the default rate if none is given at construction time.
+ self.dropout_rate = 0.1 if dropout_rate is None else dropout_rate
+ # in_features are specifically chosen for the ogbg workload.
+ self.node_embedder = nn.Linear(in_features=9, out_features=self.latent_dim)
+ self.edge_embedder = nn.Linear(in_features=3, out_features=self.latent_dim)
+
+ if activation_fn_name == 'relu':
+ activation_fn = nn.ReLU
+ elif activation_fn_name == 'gelu':
+ activation_fn = partial(nn.GELU, approximate='tanh')
+ elif activation_fn_name == 'silu':
+ activation_fn = nn.SiLU
+ else:
+ raise ValueError(
+ f'Invalid activation function name: {activation_fn_name}')
+
+ graph_network_layers = []
+ for st in range(self.num_message_passing_steps):
+ # Constants in in_dims are based on forward call of GraphNetwork:
+ # specifically update_edge_fn update_node_fn and update_global_fn.
+ if st == 0:
+ in_dim_edge_fn = self.latent_dim * 3 + self.num_outputs
+ in_dim_node_fn = self.latent_dim + self.hidden_dims[
+ -1] * 2 + self.num_outputs
+ last_in_dim = self.hidden_dims[-1] * 2 + self.num_outputs
+ else:
+ in_dim_edge_fn = self.hidden_dims[-1] * 4
+ in_dim_node_fn = self.hidden_dims[-1] * 4
+ last_in_dim = self.hidden_dims[-1] * 3
+
+ graph_network_layers.append(
+ GraphNetwork(
+ update_edge_fn=_make_mlp(in_dim_edge_fn,
+ self.hidden_dims,
+ activation_fn),
+ update_node_fn=_make_mlp(in_dim_node_fn,
+ self.hidden_dims,
+ activation_fn),
+ update_global_fn=_make_mlp(last_in_dim,
+ self.hidden_dims,
+ activation_fn)))
+ self.graph_network = SequentialWithDropout(*graph_network_layers)
+
+ self.decoder = nn.Linear(
+ in_features=self.hidden_dims[-1], out_features=self.num_outputs)
+
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ init_utils.pytorch_default_init(m)
+
+ def forward(self, graph: GraphsTuple, dropout_rate=None) -> torch.Tensor:
+ if dropout_rate is None:
+ dropout_rate = self.dropout_rate
+
+ graph = graph._replace(
+ globals=torch.zeros([graph.n_node.shape[0], self.num_outputs],
+ device=graph.n_node.device))
+ graph = graph._replace(nodes=self.node_embedder(graph.nodes))
+ graph = graph._replace(edges=self.edge_embedder(graph.edges))
+
+ graph = self.graph_network(graph, dropout_rate)
+
+ # Map globals to represent the final result
+ graph = graph._replace(globals=self.decoder(graph.globals))
+
+ return graph.globals
+
+
+class GraphNetwork(nn.Module):
+ """Returns a method that applies a configured GraphNetwork.
+ This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261
+ There is one difference. For the nodes update the class aggregates over the
+ sender edges and receiver edges separately. This is a bit more general
+ than the algorithm described in the paper. The original behaviour can be
+ recovered by using only the receiver edge aggregations for the update.
+ In addition this implementation supports softmax attention over incoming
+ edge features.
+ Example usage::
+ gn = GraphNetwork(update_edge_function,
+ update_node_function, **kwargs)
+ # Conduct multiple rounds of message passing with the same parameters:
+ for _ in range(num_message_passing_steps):
+ graph = gn(graph)
+ Args:
+ update_edge_fn: function used to update the edges or None to deactivate edge
+ updates.
+ update_node_fn: function used to update the nodes or None to deactivate node
+ updates.
+ update_global_fn: function used to update the globals or None to deactivate
+ globals updates.
+ Returns:
+ A method that applies the configured GraphNetwork.
+ """
+
+ def __init__(self,
+ update_edge_fn: Optional[Callable] = None,
+ update_node_fn: Optional[Callable] = None,
+ update_global_fn: Optional[Callable] = None) -> None:
+ super().__init__()
+ self.update_edge_fn = update_edge_fn
+ self.update_node_fn = update_node_fn
+ self.update_global_fn = update_global_fn
+ self._supports_custom_dropout = True # supports SequentialWithDropout
+
+ def forward(self, graph: GraphsTuple, dropout_rate=None) -> GraphsTuple:
+ """Applies a configured GraphNetwork to a graph.
+ This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261
+ There is one difference. For the nodes update the class aggregates over the
+ sender edges and receiver edges separately. This is a bit more general
+ than the algorithm described in the paper. The original behaviour can be
+ recovered by using only the receiver edge aggregations for the update.
+ In addition this implementation supports softmax attention over incoming
+ edge features.
+ Many popular Graph Neural Networks can be implemented as special cases of
+ GraphNets, for more information please see the paper.
+ Args:
+ graph: a `GraphsTuple` containing the graph.
+ Returns:
+ Updated `GraphsTuple`.
+ """
+ nodes, edges, receivers, senders, globals_, n_node, n_edge = graph
+ sum_n_node = tree.tree_leaves(nodes)[0].shape[0]
+ if not tree.tree_all(
+ tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)):
+ raise ValueError(
+ 'All node arrays in nest must contain the same number of nodes.')
+
+ sent_attributes = tree.tree_map(lambda n: n[senders], nodes)
+ received_attributes = tree.tree_map(lambda n: n[receivers], nodes)
+ # Here we scatter the global features to the corresponding edges,
+ # giving us tensors of shape [num_edges, global_feat].
+ global_edge_attributes = tree.tree_map(
+ lambda g: torch.repeat_interleave(g, n_edge, dim=0), globals_)
+ if self.update_edge_fn:
+ edge_fn_inputs = torch.cat(
+ [edges, sent_attributes, received_attributes, global_edge_attributes],
+ dim=-1)
+ edges = self.update_edge_fn(edge_fn_inputs, dropout_rate)
+
+ if self.update_node_fn:
+ sent_attributes = tree.tree_map(
+ lambda e: scatter_sum(e, senders, dim=0, dim_size=sum_n_node), edges)
+ received_attributes = tree.tree_map(
+ lambda e: scatter_sum(e, receivers, dim=0, dim_size=sum_n_node),
+ edges)
+ # Here we scatter the global features to the corresponding nodes,
+ # giving us tensors of shape [num_nodes, global_feat].
+ global_attributes = tree.tree_map(
+ lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_)
+ node_fn_inputs = torch.cat(
+ [nodes, sent_attributes, received_attributes, global_attributes],
+ dim=-1)
+ nodes = self.update_node_fn(node_fn_inputs, dropout_rate)
+
+ if self.update_global_fn:
+ n_graph = n_node.shape[0]
+ graph_idx = torch.arange(n_graph, device=graph.n_node.device)
+ # To aggregate nodes and edges from each graph to global features,
+ # we first construct tensors that map the node to the corresponding graph.
+ # For example, if you have `n_node=[1,2]`, we construct the tensor
+ # [0, 1, 1]. We then do the same for edges.
+ node_gr_idx = torch.repeat_interleave(graph_idx, n_node, dim=0)
+ edge_gr_idx = torch.repeat_interleave(graph_idx, n_edge, dim=0)
+ # We use the aggregation function to pool the nodes/edges per graph.
+ node_attributes = tree.tree_map(
+ lambda n: scatter_sum(n, node_gr_idx, dim=0, dim_size=n_graph), nodes)
+ edge_attributes = tree.tree_map(
+ lambda e: scatter_sum(e, edge_gr_idx, dim=0, dim_size=n_graph), edges)
+ # These pooled nodes are the inputs to the global update fn.
+ global_fn_inputs = torch.cat([node_attributes, edge_attributes, globals_],
+ dim=-1)
+ globals_ = self.update_global_fn(global_fn_inputs, dropout_rate)
+
+ return GraphsTuple(
+ nodes=nodes,
+ edges=edges,
+ receivers=receivers,
+ senders=senders,
+ globals=globals_,
+ n_node=n_node,
+ n_edge=n_edge)
+
+
+# Forked from
+# github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/scatter.py.
+def scatter_sum(src: torch.Tensor,
+ index: torch.Tensor,
+ dim: int = -1,
+ out: Optional[torch.Tensor] = None,
+ dim_size: Optional[int] = None) -> torch.Tensor:
+ r"""
+ |
+ .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
+ master/docs/source/_figures/add.svg?sanitize=true
+ :align: center
+ :width: 400px
+ |
+ Reduces all values from the :attr:`src` tensor into :attr:`out` at the
+ indices specified in the :attr:`index` tensor along a given axis
+ :attr:`dim`.
+ For each value in :attr:`src`, its output index is specified by its index
+ in :attr:`src` for dimensions outside of :attr:`dim` and by the
+ corresponding value in :attr:`index` for dimension :attr:`dim`.
+ The applied reduction is here defined as a sum.
+ Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional
+ tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})`
+ and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional
+ tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`.
+ Moreover, the values of :attr:`index` must be between :math:`0` and
+ :math:`y - 1`, although no specific ordering of indices is required.
+ The :attr:`index` tensor supports broadcasting in case its dimensions do
+ not match with :attr:`src`.
+ For one-dimensional tensors, the operation computes
+ .. math::
+ \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j
+ where :math:`\sum_j` is over :math:`j` such that
+ :math:`\mathrm{index}_j = i`.
+ .. note::
+ This operation is implemented via atomic operations on the GPU and is
+ therefore **non-deterministic** since the order of parallel operations
+ to the same value is undetermined.
+ For floating-point variables, this results in a source of variance in
+ the result.
+ :param src: The source tensor.
+ :param index: The indices of elements to scatter.
+ :param dim: The axis along which to index. (default: :obj:`-1`)
+ :param out: The destination tensor.
+ :param dim_size: If :attr:`out` is not given, automatically create output
+ with size :attr:`dim_size` at dimension :attr:`dim`.
+ If :attr:`dim_size` is not given, a minimal sized output tensor
+ according to :obj:`index.max() + 1` is returned.
+ :rtype: :class:`Tensor`
+ .. code-block:: python
+ src = torch.randn(10, 6, 64)
+ index = torch.tensor([0, 1, 0, 1, 2, 1])
+ # Broadcasting in the first and last dim.
+ out = scatter_sum(src, index, dim=1)
+ print(out.size())
+ .. code-block::
+ torch.Size([10, 3, 64])
+ """
+ index = broadcast(index, src, dim)
+ if out is None:
+ size = list(src.size())
+ if dim_size is not None:
+ size[dim] = dim_size
+ elif index.numel() == 0:
+ size[dim] = 0
+ else:
+ size[dim] = int(index.max()) + 1
+ out = torch.zeros(size, dtype=src.dtype, device=src.device)
+ return out.scatter_add_(dim, index, src)
+ else:
+ return out.scatter_add_(dim, index, src)
+
+
+# Forked from
+# github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/utils.py.
+def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
+ if dim < 0:
+ dim = other.dim() + dim
+ if src.dim() == 1:
+ for _ in range(0, dim):
+ src = src.unsqueeze(0)
+ for _ in range(src.dim(), other.dim()):
+ src = src.unsqueeze(-1)
+ src = src.expand(other.size())
+ return src
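
For reference, a minimal example of the scatter_sum helper above, summing per-edge
features into their receiver nodes the way the node update does; the shapes and
index values are chosen purely for illustration and are not part of the patch:

  import torch

  edge_feats = torch.ones(5, 3)              # 5 edges, feature width 3
  receivers = torch.tensor([0, 0, 1, 2, 2])  # receiver node of each edge
  node_sums = scatter_sum(edge_feats, receivers, dim=0, dim_size=4)
  # node_sums has shape (4, 3); nodes 0 and 2 each accumulate two edges,
  # while node 3 receives no edges and stays zero.
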
From e80add440b4ab281414d7babf4f5492c16d758e2 Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Wed, 11 Jun 2025 11:23:05 +0200
Subject: [PATCH 034/123] remove attention_dropout_rate from wmt
---
.../wmt/wmt_pytorch/models_dropout.py | 981 ++++++++++++++++++
1 file changed, 981 insertions(+)
create mode 100644 algoperf/workloads/wmt/wmt_pytorch/models_dropout.py
diff --git a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py
new file mode 100644
index 000000000..588d06abf
--- /dev/null
+++ b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py
@@ -0,0 +1,981 @@
+import copy
+import math
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+from torch.nn.init import normal_
+from torch.nn.init import xavier_uniform_
+
+
+def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor:
+ """Make a causal mask for self-attention.
+
+ Args:
+ x: input array of shape `[batch..., len]`
+ device: device to store the idxs
+
+ Returns:
+ A `[batch..., len, len]` shaped causal attention mask.
+ """
+ idxs = torch.broadcast_to(
+ torch.arange(x.shape[-1], dtype=torch.int32, device=device), x.shape)
+ return torch.greater_equal(idxs.unsqueeze(-1), idxs.unsqueeze(-2))
+
+
+def make_src_mask(src, inputs_segmentation, nhead):
+ """Utility for creating src mask and adjusting it for PyTorch Transformer API."""
+ src_mask = torch.mul((src > 0).unsqueeze(-1), (src > 0).unsqueeze(-2))
+ # Add segmentation block-diagonal attention mask if using segmented data.
+ if inputs_segmentation is not None:
+ src_mask = torch.logical_and(
+ src_mask,
+ torch.eq(
+ inputs_segmentation.unsqueeze(-1),
+ inputs_segmentation.unsqueeze(-2)))
+ # Flip values and ensure numerical stability.
+ src_mask = torch.repeat_interleave(
+ torch.logical_not(src_mask), repeats=nhead, dim=0)
+ new_src_mask = torch.zeros_like(src_mask, dtype=torch.float32)
+ new_src_mask.masked_fill_(src_mask, -1e10)
+ return new_src_mask
+
+
+def make_tgt_and_memory_mask(tgt,
+ src,
+ inputs_segmentation,
+ targets_segmentation,
+ decode,
+ nhead):
+ """Utility for creating target and memory masks and adjusting them for PyTorch
+ Transformer API."""
+ if not decode:
+ tgt_mask = torch.logical_and(
+ torch.mul((tgt > 0).unsqueeze(-1), (tgt > 0).unsqueeze(-2)),
+ make_causal_mask(tgt, device=tgt.device))
+ memory_mask = torch.mul((tgt > 0).unsqueeze(-1), (src > 0).unsqueeze(-2))
+ else:
+ tgt_mask = None
+ memory_mask = torch.mul((torch.ones_like(tgt) > 0).unsqueeze(-1),
+ (src > 0).unsqueeze(-2))
+ # Add segmentation block-diagonal attention masks if using segmented data.
+ if inputs_segmentation is not None:
+ tgt_mask = torch.logical_and(
+ tgt_mask,
+ torch.eq(
+ targets_segmentation.unsqueeze(-1),
+ targets_segmentation.unsqueeze(-2)))
+ memory_mask = torch.logical_and(
+ memory_mask,
+ torch.eq(
+ targets_segmentation.unsqueeze(-1),
+ inputs_segmentation.unsqueeze(-2)))
+ # Flip values and ensure numerical stability.
+ memory_mask = torch.repeat_interleave(
+ torch.logical_not(memory_mask), repeats=nhead, dim=0)
+ new_memory_mask = torch.zeros_like(memory_mask, dtype=torch.float32)
+ new_memory_mask.masked_fill_(memory_mask, -1e10)
+ if tgt_mask is not None:
+ tgt_mask = torch.repeat_interleave(
+ torch.logical_not(tgt_mask), repeats=nhead, dim=0)
+ new_tgt_mask = torch.zeros_like(tgt_mask, dtype=torch.float32)
+ new_tgt_mask.masked_fill_(tgt_mask, -1e10)
+ tgt_mask = new_tgt_mask
+ return tgt_mask, new_memory_mask
+
+
+def shift_right(x, axis=1):
+ """Shift the input to the right by padding on axis 1."""
+ pad_widths = [(0, 0)] * len(x.shape)
+ pad_widths[axis] = (1, 0)
+ pad_widths = tuple(t for tup in reversed(pad_widths) for t in tup)
+ padded = F.pad(x, pad_widths, mode='constant')
+ return padded[:, :-1]
+
+
+class Transformer(nn.Module):
+ """Transformer architecture based on the model from the WMT Jax workload."""
+
+ def __init__(self,
+ ntoken: int = 32000,
+ d_model: int = 1024,
+ nhead: int = 16,
+ d_hid: int = 1024,
+ nlayers: int = 6,
+ dropout_rate: Optional[float] = 0.1,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ glu: bool = False,
+ layer_norm_eps: float = 1e-6,
+ attention_temp: float = 1.0,
+ pre_ln: bool = True):
+ super().__init__()
+ if dropout_rate is None:
+ dropout_rate = 0.1
+ self.pos_encoder = PositionalEncoding(d_model, dropout_rate)
+ self.shared_embedding = nn.Embedding(ntoken, d_model)
+ self.encoder = Encoder(d_model,
+ nhead,
+ d_hid,
+ nlayers,
+ dropout_rate,
+ activation,
+ glu,
+ layer_norm_eps,
+ attention_temp,
+ pre_ln)
+ self.decoder = Decoder(d_model,
+ nhead,
+ d_hid,
+ nlayers,
+ dropout_rate,
+ activation,
+ glu,
+ layer_norm_eps,
+ attention_temp,
+ pre_ln)
+ # Share positional encoding and embedding between encoder and decoder.
+ self.encoder.pos_encoder = self.pos_encoder
+ self.encoder.shared_embedding = self.shared_embedding
+ self.decoder.pos_encoder = self.pos_encoder
+ self.decoder.shared_embedding = self.shared_embedding
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ """Initiate parameters in the transformer model."""
+ for module in self.modules():
+ if isinstance(module, nn.Linear):
+ xavier_uniform_(module.weight)
+ if module.bias is not None:
+ normal_(module.bias, std=1e-6)
+
+ def forward(self,
+ src: Tensor,
+ tgt: Tensor,
+ inputs_positions: Optional[Tensor] = None,
+ targets_positions: Optional[Tensor] = None,
+ inputs_segmentation: Optional[Tensor] = None,
+ targets_segmentation: Optional[Tensor] = None,
+ decode: bool = False) -> Tensor:
+ """
+ Args:
+ src: Tensor, shape [batch_size, seq_len]
+ tgt: Tensor, shape [batch_size, seq_len]
+ inputs_positions: Optional[Tensor], shape [batch_size, seq_len]
+ targets_positions: Optional[Tensor], shape [batch_size, seq_len]
+ inputs_segmentation: Optional[Tensor], shape [batch_size, seq_len]
+ targets_segmentation: Optional[Tensor], shape [batch_size, seq_len]
+ decode: bool
+
+ Returns:
+ output Tensor of shape [batch_size, seq_len, ntoken]
+ """
+ if src.size(0) != tgt.size(0):
+ raise RuntimeError('The batch size of src and tgt must be equal.')
+ memory = self.encoder(
+ src,
+ inputs_positions=inputs_positions,
+ inputs_segmentation=inputs_segmentation)
+ output = self.decoder(
+ tgt,
+ memory,
+ src, # just for calculating the padding mask
+ targets_positions=targets_positions,
+ inputs_segmentation=inputs_segmentation,
+ targets_segmentation=targets_segmentation,
+ decode=decode)
+ return output
+
+
+class TransformerEncoder(nn.Module):
+ r"""TransformerEncoder is a stack of N encoder layers. Users can build the
+ BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
+
+ Args:
+ encoder_layer: an instance of the TransformerEncoderLayer() class.
+ num_layers: the number of sub-encoder-layers in the encoder.
+ norm: the layer normalization component (optional).
+ enable_nested_tensor: if True, input will automatically convert to
+ nested tensor (and convert back on output). This will improve
+ the overall performance of TransformerEncoder when padding
+ rate is high.
+
+ Examples::
+ >>> encoder_layer = nn.TransformerEncoderLayer(12, 8)
+ >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, 6)
+ >>> src = torch.rand(10, 32, 512)
+ >>> out = transformer_encoder(src)
+ """
+ __constants__ = ['norm']
+
+ def __init__(self,
+ encoder_layer,
+ num_layers,
+ norm=None,
+ enable_nested_tensor=True,
+ mask_check=True):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [copy.deepcopy(encoder_layer) for _ in range(num_layers)])
+ self.num_layers = num_layers
+ self.norm = norm
+ self.enable_nested_tensor = enable_nested_tensor
+ self.mask_check = mask_check
+
+ def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+ """Pass the input through the encoder layers in turn.
+
+ Args:
+ src: the sequence to the encoder (required).
+ mask: the mask for the src sequence (optional).
+
+ Shape:
+ see the docs in Transformer class.
+ """
+ output = src
+ convert_to_nested = False
+
+ for mod in self.layers:
+ output = mod(output, src_mask=mask)
+
+ if convert_to_nested:
+ output = output.to_padded_tensor(0.)
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ return output
+
+
+class Encoder(nn.Module):
+
+ def __init__(self,
+ d_model: int = 1024,
+ nhead: int = 16,
+ d_hid: int = 1024,
+ nlayers: int = 6,
+ dropout_rate: float = 0.1,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ glu: bool = False,
+ layer_norm_eps: float = 1e-6,
+ attention_temp: float = 1.0,
+ pre_ln: bool = True):
+ super().__init__()
+ self.nhead = nhead
+ self.shared_embedding = None
+ self.pos_encoder = None
+ encoder_layer = TransformerEncoderLayer(
+ d_model,
+ nhead,
+ d_hid,
+ dropout_rate,
+ activation=activation,
+ glu=glu,
+ layer_norm_eps=layer_norm_eps,
+ attention_temp=attention_temp,
+ pre_ln=pre_ln)
+ encoder_norm = (
+ nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None)
+ self.encoder = TransformerEncoder(encoder_layer, nlayers, encoder_norm)
+
+ def forward(self,
+ src: Tensor,
+ inputs_positions: Optional[Tensor] = None,
+ inputs_segmentation: Optional[Tensor] = None) -> Tensor:
+ src = src.to(torch.int)
+ src_mask = make_src_mask(src, inputs_segmentation, self.nhead)
+ src = self.shared_embedding(src)
+ src = self.pos_encoder(src, inputs_positions)
+ memory = self.encoder(src, mask=src_mask)
+ return memory
+
+
+class Decoder(nn.Module):
+
+ def __init__(self,
+ d_model: int = 1024,
+ nhead: int = 16,
+ d_hid: int = 1024,
+ nlayers: int = 6,
+ dropout_rate: float = 0.1,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ glu: bool = False,
+ layer_norm_eps: float = 1e-6,
+ attention_temp: float = 1.0,
+ pre_ln: bool = True):
+ super().__init__()
+ self.nhead = nhead
+ self.shared_embedding = None
+ self.pos_encoder = None
+ self.decoder = TransformerDecoder(d_model,
+ nhead,
+ d_hid,
+ dropout_rate,
+ activation,
+ glu,
+ layer_norm_eps,
+ nlayers,
+ attention_temp,
+ pre_ln)
+
+ def forward(
+ self,
+ tgt: Tensor,
+ memory: Tensor,
+ src: Tensor, # just for calculating the padding mask
+ targets_positions: Optional[Tensor] = None,
+ inputs_segmentation: Optional[Tensor] = None,
+ targets_segmentation: Optional[Tensor] = None,
+ decode: bool = False,
+ max_len: Optional[int] = None,
+ cache: Optional[dict] = None) -> Any:
+ tgt = tgt.to(torch.int)
+ tgt_mask, memory_mask = make_tgt_and_memory_mask(
+ tgt, src, inputs_segmentation, targets_segmentation,
+ decode, self.nhead)
+ if not decode:
+ tgt = shift_right(tgt)
+ tgt = self.shared_embedding(tgt)
+ tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache)
+ if decode:
+ tgt, cache = tgt
+ output = self.decoder(
+ tgt,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ decode=decode,
+ max_len=max_len,
+ cache=cache)
+ if decode:
+ output, cache = output
+ normalize = math.sqrt(output.shape[-1])
+ output = torch.matmul(output, self.shared_embedding.weight.T) / normalize
+ if decode:
+ return output, cache
+ return output
+
+
+class PositionalEncoding(nn.Module):
+
+ def __init__(self,
+ d_model: int,
+ dropout_rate: float = 0.1,
+ max_len: int = 256):
+ super().__init__()
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ position = torch.arange(max_len).unsqueeze(1)
+ scale_factor = -math.log(10000.0) / (d_model // 2 - 1)
+ div_term = torch.exp(torch.arange(d_model // 2) * scale_factor)
+ pe = torch.zeros(1, max_len, d_model)
+ pe[0, :, :d_model // 2] = torch.sin(position * div_term)
+ pe[0, :, d_model // 2:2 * (d_model // 2)] = torch.cos(position * div_term)
+ self.register_buffer('pe', pe)
+
+ def forward(
+ self,
+ x: Tensor,
+ inputs_positions: Optional[Tensor] = None,
+ decode: bool = False,
+ cache: Optional[Dict[str, Dict[str, Tensor]]] = None
+ ) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]:
+ """
+ Args:
+ x: Tensor (shape [batch_size, seq_len, embedding_dim])
+ inputs_positions: Tensor (shape [batch_size, seq_len]) or None
+ decode: bool
+ cache: Dict[str, Dict[str, Tensor]] or None
+ Returns:
+ Tensor or Tuple[Tensor, Dict[str, Dict[str, Tensor]]]
+ """
+ # We use a cache position index for tracking decoding position.
+ if decode:
+ name = self._get_name()
+ if cache is None:
+ cache = {
+ name: {
+ 'cache_index':
+ torch.tensor(0, dtype=torch.long, device=self.pe.device),
+ },
+ }
+ pe = self.pe[0, cache[name]['cache_index'], :]
+ cache[name]['cache_index'] += 1
+ return self.dropout(x + pe), cache
+ if inputs_positions is None:
+ # normal unpacked case:
+ pe = self.pe[:, :x.size(1), :]
+ else:
+ # for packed data we need to use known position indices:
+ pe = self.pe[0, inputs_positions, :]
+ return self.dropout(x + pe)
+
+
+# TransformerEncoderLayer and TransformerDecoderLayer are taken from:
+# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/transformer.py
+# Main difference is the use of custom MultiheadAttention modules.
+class TransformerEncoderLayer(nn.Module):
+ r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
+ This standard encoder layer is based on the paper "Attention Is All You Need".
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
+ Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all
+ you need. In Advances in Neural Information Processing Systems,
+ pages 6000-6010. Users may modify or implement in a different way during
+ application.
+ Args:
+ d_model: the number of expected features in the input (default=1024).
+ nhead: the number of heads in the multiheadattention models (default=16).
+ dim_feedforward: the dimension of the feedforward network model
+ (default=1024).
+ dropout_rate: the dropout_rate value (default=0.1).
+ activation: the activation function of the intermediate layer, can be a
+ string ("relu" or "gelu") or a unary callable (default=F.relu).
+ layer_norm_eps: the eps value in layer normalization components
+ (default=1e-6).
+ pre_ln: if ``True``, layer norm is done prior to attention and
+ feedforward operations, respectively. Otherwise it's done after.
+ Default: ``True``.
+ Examples::
+ >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
+ >>> src = torch.rand(32, 10, 512)
+ >>> out = encoder_layer(src)
+ """
+ __constants__ = ['pre_ln']
+
+ def __init__(self,
+ d_model: int = 1024,
+ nhead: int = 16,
+ dim_feedforward: int = 1024,
+ dropout_rate: float = 0.1,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ glu: bool = False,
+ layer_norm_eps: float = 1e-6,
+ attention_temp: float = 1.0,
+ pre_ln: bool = True,
+ device=None,
+ dtype=None) -> None:
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ super().__init__()
+ self.self_attn = MultiheadAttention(
+ d_model,
+ nhead,
+ self_attn=True,
+ dropout_rate=dropout_rate,
+ attention_temp=attention_temp,
+ bias=False,
+ **factory_kwargs)
+
+ # Implementation of Feedforward model.
+ self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
+ self.glu = glu
+ if self.glu:
+ self.linear_glu = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
+
+ self.pre_ln = pre_ln
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
+ self.dropout1 = nn.Dropout(dropout_rate)
+ self.dropout2 = nn.Dropout(dropout_rate)
+
+ self.activation = activation
+
+ def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
+ r"""Pass the input through the encoder layer.
+
+ Args:
+ src: the sequence to the encoder layer (required).
+ src_mask: the mask for the src sequence (optional).
+
+ Shape:
+ see the docs in Transformer class.
+ """
+ x = src
+ if self.pre_ln:
+ x = x + self._sa_block(self.norm1(x), src_mask)
+ x = x + self._ff_block(self.norm2(x))
+ else:
+ x = self.norm1(x + self._sa_block(x, src_mask))
+ x = self.norm2(x + self._ff_block(x))
+
+ return x
+
+ # Self-attention block:
+ def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tensor:
+ x, _ = self.self_attn(x, attn_mask=attn_mask)
+ return self.dropout1(x)
+
+ # Feed forward block:
+ def _ff_block(self, inputs: Tensor) -> Tensor:
+ x = self.activation(self.linear1(inputs))
+ if self.glu:
+ y = self.linear_glu(inputs)
+ x = x * y
+ x = self.linear2(self.dropout(x))
+ return self.dropout2(x)
+
+
+# Modified to use cache for autoregressive decoding and custom
+# MultiheadAttention modules.
+class TransformerDecoder(nn.Module):
+ r"""TransformerDecoder is a stack of N decoder layers
+ Args:
+ d_model: the number of expected features in the input (default=1024)
+ nhead: the number of heads in the multiheadattention models (default=16)
+ d_hid: the dimension of the feedforward network model
+ (default=1024)
+ dropout_rate: the dropout_rate value (default=0.1)
+ layer_norm_eps: the eps value in layer normalization components
+ (default=1e-6).
+ decoder_layer: an instance of the TransformerDecoderLayer() class
+ num_layers: the number of sub-decoder-layers in the decoder
+ Examples::
+ >>> decoder_layer = nn.TransformerDecoderLayer(12, 8)
+ >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, 6)
+ >>> memory = torch.rand(10, 32, 512)
+ >>> tgt = torch.rand(20, 32, 512)
+ >>> out = transformer_decoder(tgt, memory)
+ """
+ __constants__ = ['norm']
+
+ def __init__(self,
+ d_model,
+ nhead,
+ d_hid,
+ dropout_rate,
+ activation,
+ glu,
+ layer_norm_eps,
+ num_layers,
+ attention_temp,
+ pre_ln):
+ super().__init__()
+ self.layers = nn.ModuleList([
+ TransformerDecoderLayer(
+ d_model,
+ nhead,
+ d_hid,
+ dropout_rate,
+ activation,
+ glu,
+ layer_norm_eps=layer_norm_eps,
+ attention_temp=attention_temp,
+ pre_ln=pre_ln) for _ in range(num_layers)
+ ])
+ self.num_layers = num_layers
+ self.norm = (nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None)
+
+ def forward(self,
+ tgt: Tensor,
+ memory: Tensor,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ decode: bool = False,
+ max_len: Optional[int] = None,
+ cache: Optional[dict] = None) -> Any:
+ r"""Pass the inputs (and mask) through the decoder layer in turn.
+ Args:
+ tgt: the sequence to the decoder (required).
+ memory: the sequence from the last layer of the encoder (required).
+ tgt_mask: the mask for the tgt sequence (optional).
+ memory_mask: the mask for the memory sequence (optional).
+ decode: whether to use cache for autoregressive decoding or not.
+ max_len: maximum sequence length, necessary for decoding cache.
+ Shape:
+ see the docs in Transformer class.
+ """
+ output = tgt
+
+ for idx, mod in enumerate(self.layers):
+ output, cache = mod(
+ output,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ decode=decode,
+ max_len=max_len,
+ cache=cache,
+ index=idx)
+
+ if self.norm is not None:
+ output = self.norm(output)
+
+ if decode:
+ return output, cache
+ return output
+
+
+# Modified to use cache for autoregressive decoding and custom
+# MultiheadAttention modules.
+class TransformerDecoderLayer(nn.Module):
+ r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and
+ feedforward network.
+ This standard decoder layer is based on the paper "Attention Is All You Need".
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
+ Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all
+ you need. In Advances in Neural Information Processing Systems,
+ pages 6000-6010. Users may modify or implement in a different way during
+ application.
+ Args:
+ d_model: the number of expected features in the input (default=1024).
+ nhead: the number of heads in the multiheadattention models (default=16).
+ dim_feedforward: the dimension of the feedforward network model
+ (default=1024).
+ dropout_rate: the dropout_rate value (default=0.1).
+ activation: the activation function of the intermediate layer, can be a
+ string ("relu" or "gelu") or a unary callable (default=F.relu).
+ layer_norm_eps: the eps value in layer normalization components
+ (default=1e-6).
+ pre_ln: if ``True``, layer norm is done prior to self attention,
+ multihead attention and feedforward operations, respectively.
+ Otherwise it's done after. Default: ``True``.
+ Examples::
+ >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
+ >>> memory = torch.rand(32, 10, 512)
+ >>> tgt = torch.rand(32, 20, 512)
+ >>> out = decoder_layer(tgt, memory)
+ """
+ __constants__ = ['pre_ln']
+
+ def __init__(self,
+ d_model: int = 1024,
+ nhead: int = 16,
+ dim_feedforward: int = 1024,
+ dropout_rate: float = 0.1,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ glu: bool = False,
+ layer_norm_eps: float = 1e-6,
+ pre_ln: bool = True,
+ attention_temp: float = 1.0,
+ device=None,
+ dtype=None) -> None:
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ super().__init__()
+ self.self_attn = MultiheadAttention(
+ d_model,
+ nhead,
+ self_attn=True,
+ dropout_rate=dropout_rate,
+ attention_temp=attention_temp,
+ bias=False,
+ **factory_kwargs)
+ self.multihead_attn = MultiheadAttention(
+ d_model,
+ nhead,
+ self_attn=False,
+ dropout_rate=dropout_rate,
+ attention_temp=attention_temp,
+ bias=False,
+ **factory_kwargs)
+
+ # Implementation of Feedforward model.
+ self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
+ self.glu = glu
+ if self.glu:
+ self.linear_glu = nn.Linear(dim_feedforward,
+ dim_feedforward,
+ **factory_kwargs)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
+
+ self.pre_ln = pre_ln
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
+ self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
+ self.dropout1 = nn.Dropout(dropout_rate)
+ self.dropout2 = nn.Dropout(dropout_rate)
+ self.dropout3 = nn.Dropout(dropout_rate)
+
+ self.activation = activation
+
+ def forward( # pylint: disable=arguments-renamed
+ self,
+ tgt: Tensor,
+ memory: Tensor,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ decode: bool = False,
+ max_len: Optional[int] = None,
+ cache: Optional[dict] = None,
+ index: Optional[int] = None) -> Any:
+ r"""Pass the inputs (and mask) through the decoder layer.
+ Args:
+ tgt: the sequence to the decoder layer (required).
+ memory: the sequence from the last layer of the encoder (required).
+ tgt_mask: the mask for the tgt sequence (optional).
+ memory_mask: the mask for the memory sequence (optional).
+ decode: whether to use cache for autoregressive decoding or not.
+ max_len: maximum sequence length, necessary for decoding cache.
+ Shape:
+ see the docs in Transformer class.
+ """
+ # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
+
+ x = tgt
+ if self.pre_ln:
+ sa_out, cache = self._sa_block(
+ self.norm1(x),
+ tgt_mask,
+ decode=decode,
+ max_len=max_len,
+ cache=cache,
+ index=index)
+ x = x + sa_out
+ x = x + self._mha_block(self.norm2(x), memory, memory_mask)
+ x = x + self._ff_block(self.norm3(x))
+ else:
+ sa_out, cache = self._sa_block(
+ x,
+ tgt_mask,
+ decode=decode,
+ max_len=max_len,
+ cache=cache,
+ index=index)
+ x = self.norm1(x + sa_out)
+ x = self.norm2(x + self._mha_block(x, memory, memory_mask))
+ x = self.norm3(x + self._ff_block(x))
+
+ return x, cache
+
+ # Self-attention block:
+ def _sa_block( # pylint: disable=arguments-renamed
+ self,
+ x: Tensor,
+ attn_mask: Optional[Tensor],
+ decode: bool = False,
+ max_len: Optional[int] = None,
+ cache: Optional[dict] = None,
+ index: Optional[int] = None) -> Any:
+ x, cache = self.self_attn(
+ x,
+ attn_mask=attn_mask,
+ decode=decode,
+ max_len=max_len,
+ cache=cache,
+ index=index)
+ return self.dropout1(x), cache
+
+ # Multihead attention block:
+ def _mha_block(self, x: Tensor, mem: Tensor,
+ attn_mask: Optional[Tensor]) -> Tensor:
+ x, _ = self.multihead_attn(x, mem, attn_mask=attn_mask)
+ return self.dropout2(x)
+
+ # Feed forward block.
+ def _ff_block(self, inputs: Tensor) -> Tensor:
+ x = self.activation(self.linear1(inputs))
+ if self.glu:
+ y = self.linear_glu(inputs)
+ x = x * y
+ x = self.linear2(self.dropout(x))
+ return self.dropout3(x)
+
+
+class MultiheadAttention(nn.Module):
+ r"""Allows the model to jointly attend to information
+ from different representation subspaces. Supports self-attention and
+ encoder-decoder attention.
+ See `Attention Is All You Need `_.
+ .. math::
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
+ Args:
+ embed_dim: Total dimension of the model.
+ num_heads: Number of parallel attention heads. Note that ``embed_dim`` will
+ be split across ``num_heads`` (i.e. each head will have dimension
+ ``embed_dim // num_heads``).
+ self_attn: Whether self attention or encoder-decoder attention is used.
+ Default: ``True``.
+ dropout_rate: Dropout probability on ``attn_output_weights``.
+ Default: ``0.0`` (no dropout_rate).
+ bias: If specified, adds bias to input / output projection layers.
+ Default: ``False``.
+ device: The device of the module.
+ dtype: The dtype of the module.
+ Examples::
+ >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
+ >>> attn_output, cache = multihead_attn(x)
+ """
+
+ def __init__(self,
+ embed_dim: int,
+ num_heads: int,
+ self_attn: bool = True,
+ dropout_rate: float = 0.,
+ attention_temp: float = 1.0,
+ bias: bool = False,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None) -> None:
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.self_attn = self_attn
+ self.dropout = dropout_rate
+ self.head_dim = embed_dim // num_heads
+ self.attention_temp = attention_temp
+ assert self.head_dim * num_heads == self.embed_dim, \
+ 'embed_dim must be divisible by num_heads.'
+
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ if self_attn:
+ # Self-attention.
+ self.in_proj = nn.Linear(
+ embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
+ else:
+ # Encoder-decoder attention.
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
+ self.kv_proj = nn.Linear(
+ embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ """Initiate parameters in the MultiheadAttention module."""
+ for module in self.modules():
+ if isinstance(module, nn.Linear):
+ xavier_uniform_(module.weight)
+ if module.bias is not None:
+ normal_(module.bias, std=1e-6)
+
+ def forward(self,
+ x: Tensor,
+ mem: Optional[Tensor] = None,
+ attn_mask: Optional[Tensor] = None,
+ decode: bool = False,
+ max_len: Optional[int] = None,
+ cache: Optional[dict] = None,
+ index: Optional[int] = None) -> Any:
+ r"""
+ Args:
+ x: Batch of input sequences of shape
+ (batch size, sequence length, embedding dimensionality) for self
+ attention mechanism. See "Attention Is All You Need" for more details.
+ mem: Batch of input sequences of shape
+ (batch size, sequence length, embedding dimensionality) for
+ encoder-decoder attention. See "Attention Is All You Need" for more
+ details.
+ attn_mask: If specified, a 2D or 3D mask preventing attention to certain
+ positions. Must be of shape :math:`(L, S)` or
+ :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the
+ batch size, :math:`L` is the target sequence length, and :math:`S`
+ is the source sequence length. A 2D mask will be broadcasted across
+ the batch while a 3D mask allows for a different mask for each entry
+ in the batch. Binary, byte, and float masks are supported.
+ For a binary mask, a ``True`` value indicates that the
+ corresponding position is not allowed to attend. For a byte mask,
+ a non-zero value indicates that the corresponding position is not
+ allowed to attend. For a float mask, the mask values will be added to
+ the attention weight.
+ decode: whether to use cache for autoregressive decoding or not.
+ max_len: maximum sequence length, necessary for decoding cache.
+ cache: cache dictionary for autoregressive decoding.
+ index: index of the current decoding step, necessary for decoding cache.
+ Outputs:
+ - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where
+ :math:`L` is the target sequence length, :math:`N` is the batch size,
+ and :math:`E` is the embedding dimension ``embed_dim``.
+ - **cache** - For autoregressive decoding.
+ """
+ # Shape: (batch size, sequence length, embedding dimensionality)
+ bsz, seq_len, embed_dim = x.size()
+ # In projection.
+ if self.self_attn:
+ q, k, v = self.in_proj(x).split(self.embed_dim, dim=2)
+ else:
+ q = self.q_proj(x)
+ k, v = self.kv_proj(mem).split(self.embed_dim, dim=2)
+ # This is 1 (!= seq_len) during autoregressive decoding.
+ tgt_len = q.size(1)
+
+ # During fast autoregressive decoding, we feed one position at a time,
+ # and cache the keys and values step by step.
+ name = f'decoder.layers.{index}.self_attn'
+ loc_cache = cache[name] if decode and name in cache else None
+ if decode:
+ if loc_cache is None:
+ loc_cache = {
+ 'cached_key':
+ torch.zeros((bsz, max_len, embed_dim),
+ dtype=k.dtype,
+ device=k.device),
+ 'cached_value':
+ torch.zeros((bsz, max_len, embed_dim),
+ dtype=v.dtype,
+ device=v.device),
+ 'cache_index':
+ torch.tensor(0, dtype=torch.long, device=k.device),
+ }
+ cached_key = loc_cache['cached_key']
+ cached_value = loc_cache['cached_value']
+ cache_index = loc_cache['cache_index']
+ # Shape check of cached keys against query input.
+ expected_shape = (bsz, 1, embed_dim)
+ if expected_shape != x.shape:
+ raise ValueError('Autoregressive cache shape error, expected query '
+ f'shape {expected_shape} instead got {x.shape}.')
+ # Update key, value caches with our new 1d spatial slices.
+ cached_key[:, cache_index:cache_index + 1, :] = k
+ cached_value[:, cache_index:cache_index + 1, :] = v
+ k = cached_key
+ v = cached_value
+ cache_index += 1
+ # Causal mask for cached decoder self-attention:
+ # our single query position should only attend to those key
+ # positions that have already been generated and cached,
+ # not the remaining zero elements.
+ if attn_mask is not None:
+ raise ValueError('Attention mask has to be None for decode == True.')
+ attn_mask = (torch.arange(max_len, device=k.device) >=
+ cache_index).reshape(1, max_len)
+
+ # Update sequence length to account for complete sequence.
+ seq_len = k.size(1)
+
+ # Rearrange q, k, v for multihead attention.
+ q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
+ k = k.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+ v = v.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ # Check dtype and shape of attention mask.
+ if not decode and attn_mask is not None:
+ assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
+ f'Float and bool dtypes are supported, not {attn_mask.dtype}.'
+ # Ensure attn_mask's dim is 3.
+ if attn_mask.dim() == 3:
+ correct_3d_size = (bsz * self.num_heads, tgt_len, seq_len)
+ if attn_mask.shape != correct_3d_size:
+ raise RuntimeError(f'The shape of attn_mask is {attn_mask.shape}, '
+ f'but should be {correct_3d_size}.')
+ else:
+ raise RuntimeError(
+ f"attn_mask's dimension {attn_mask.dim()} is not supported")
+ # Reshape attention mask to be consistent with q, k, v.
+ attn_mask = attn_mask.view(bsz, self.num_heads, tgt_len, seq_len)
+
+ # Convert attention mask to float.
+ if attn_mask is not None and attn_mask.dtype == torch.bool:
+ new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
+ new_attn_mask.masked_fill_(attn_mask, -1e10)
+ attn_mask = new_attn_mask
+
+ # Adjust dropout_rate probability.
+ dropout_rate = self.dropout if self.training else 0.0
+
+ # Calculate attention.
+ q = self.attention_temp * q
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ q, k, v, attn_mask, dropout_rate)
+ # Rearrange for output projection.
+ attn_output = attn_output.transpose(1, 2).contiguous().view(
+ bsz, tgt_len, embed_dim)
+ # Output projection.
+ attn_output = self.out_proj(attn_output)
+
+ if decode:
+ cache[name] = loc_cache
+
+ return attn_output, cache
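
A rough end-to-end usage sketch for the Transformer defined above; all sizes here
are illustrative rather than the workload defaults:

  import torch

  model = Transformer(ntoken=1000, d_model=64, nhead=4, d_hid=128, nlayers=2)
  src = torch.randint(1, 1000, (8, 16))  # [batch_size, src_seq_len]
  tgt = torch.randint(1, 1000, (8, 16))  # [batch_size, tgt_seq_len]
  logits = model(src, tgt)               # [8, 16, 1000] vocabulary logits
  # In train mode the construction-time dropout_rate (0.1 by default) is applied.
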
From 84b1bd19bb1083947adfa87a9488b074a6b170ac Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Wed, 11 Jun 2025 15:37:33 +0200
Subject: [PATCH 035/123] dropout fix on wmt
---
.../wmt/wmt_pytorch/models_dropout.py | 168 ++++++++++--------
1 file changed, 91 insertions(+), 77 deletions(-)
diff --git a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py
index 588d06abf..c5014d87d 100644
--- a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py
+++ b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py
@@ -112,14 +112,15 @@ def __init__(self,
pre_ln: bool = True):
super().__init__()
if dropout_rate is None:
- dropout_rate = 0.1
- self.pos_encoder = PositionalEncoding(d_model, dropout_rate)
+ self.dropout_rate = 0.1
+ else:
+ self.dropout_rate = dropout_rate
+ self.pos_encoder = PositionalEncoding(d_model)
self.shared_embedding = nn.Embedding(ntoken, d_model)
self.encoder = Encoder(d_model,
nhead,
d_hid,
nlayers,
- dropout_rate,
activation,
glu,
layer_norm_eps,
@@ -129,7 +130,6 @@ def __init__(self,
nhead,
d_hid,
nlayers,
- dropout_rate,
activation,
glu,
layer_norm_eps,
@@ -158,7 +158,8 @@ def forward(self,
targets_positions: Optional[Tensor] = None,
inputs_segmentation: Optional[Tensor] = None,
targets_segmentation: Optional[Tensor] = None,
- decode: bool = False) -> Tensor:
+ decode: bool = False,
+ dropout_rate: Optional[float] = None) -> Tensor:
"""
Args:
src: Tensor, shape [batch_size, seq_len]
@@ -168,16 +169,22 @@ def forward(self,
inputs_segmentation: Optional[Tensor], shape [batch_size, seq_len]
targets_segmentation: Optional[Tensor], shape [batch_size, seq_len]
decode: bool
+ dropout_rate: Optional[float]
Returns:
output Tensor of shape [batch_size, seq_len, ntoken]
"""
if src.size(0) != tgt.size(0):
raise RuntimeError('The batch size of src and tgt must be equal.')
+
+ if dropout_rate is None:
+ dropout_rate = self.dropout_rate
+
memory = self.encoder(
src,
inputs_positions=inputs_positions,
- inputs_segmentation=inputs_segmentation)
+ inputs_segmentation=inputs_segmentation,
+ dropout_rate=dropout_rate)
output = self.decoder(
tgt,
memory,
@@ -185,7 +192,8 @@ def forward(self,
targets_positions=targets_positions,
inputs_segmentation=inputs_segmentation,
targets_segmentation=targets_segmentation,
- decode=decode)
+ decode=decode,
+ dropout_rate=dropout_rate)
return output
@@ -224,12 +232,15 @@ def __init__(self,
self.enable_nested_tensor = enable_nested_tensor
self.mask_check = mask_check
- def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+ def forward(self, src: Tensor,
+ mask: Optional[Tensor] = None,
+ dropout_rate: Optional[float] = None) -> Tensor:
"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
mask: the mask for the src sequence (optional).
+ dropout_rate: the dropout probability (optional)
Shape:
see the docs in Transformer class.
@@ -238,7 +249,7 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
convert_to_nested = False
for mod in self.layers:
- output = mod(output, src_mask=mask)
+ output = mod(output, src_mask=mask, dropout_rate=dropout_rate)
if convert_to_nested:
output = output.to_padded_tensor(0.)
@@ -256,7 +267,6 @@ def __init__(self,
nhead: int = 16,
d_hid: int = 1024,
nlayers: int = 6,
- dropout_rate: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
glu: bool = False,
layer_norm_eps: float = 1e-6,
@@ -270,7 +280,6 @@ def __init__(self,
d_model,
nhead,
d_hid,
- dropout_rate,
activation=activation,
glu=glu,
layer_norm_eps=layer_norm_eps,
@@ -283,12 +292,13 @@ def __init__(self,
def forward(self,
src: Tensor,
inputs_positions: Optional[Tensor] = None,
- inputs_segmentation: Optional[Tensor] = None) -> Tensor:
+ inputs_segmentation: Optional[Tensor] = None,
+ dropout_rate: Optional[float] = None) -> Tensor:
src = src.to(torch.int)
src_mask = make_src_mask(src, inputs_segmentation, self.nhead)
src = self.shared_embedding(src)
- src = self.pos_encoder(src, inputs_positions)
- memory = self.encoder(src, mask=src_mask)
+ src = self.pos_encoder(src, inputs_positions, dropout_rate=dropout_rate)
+ memory = self.encoder(src, mask=src_mask, dropout_rate=dropout_rate)
return memory
@@ -299,7 +309,6 @@ def __init__(self,
nhead: int = 16,
d_hid: int = 1024,
nlayers: int = 6,
- dropout_rate: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
glu: bool = False,
layer_norm_eps: float = 1e-6,
@@ -312,7 +321,6 @@ def __init__(self,
self.decoder = TransformerDecoder(d_model,
nhead,
d_hid,
- dropout_rate,
activation,
glu,
layer_norm_eps,
@@ -330,7 +338,8 @@ def forward(
targets_segmentation: Optional[Tensor] = None,
decode: bool = False,
max_len: Optional[int] = None,
- cache: Optional[dict] = None) -> Any:
+ cache: Optional[dict] = None,
+ dropout_rate: Optional[float] = None) -> Any:
tgt = tgt.to(torch.int)
tgt_mask, memory_mask = make_tgt_and_memory_mask(
tgt, src, inputs_segmentation, targets_segmentation,
@@ -338,7 +347,7 @@ def forward(
if not decode:
tgt = shift_right(tgt)
tgt = self.shared_embedding(tgt)
- tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache)
+ tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache, dropout_rate=dropout_rate)
if decode:
tgt, cache = tgt
output = self.decoder(
@@ -348,7 +357,8 @@ def forward(
memory_mask=memory_mask,
decode=decode,
max_len=max_len,
- cache=cache)
+ cache=cache,
+ dropout_rate=dropout_rate)
if decode:
output, cache = output
normalize = math.sqrt(output.shape[-1])
@@ -362,10 +372,8 @@ class PositionalEncoding(nn.Module):
def __init__(self,
d_model: int,
- dropout_rate: float = 0.1,
max_len: int = 256):
super().__init__()
- self.dropout = nn.Dropout(p=dropout_rate)
position = torch.arange(max_len).unsqueeze(1)
scale_factor = -math.log(10000.0) / (d_model // 2 - 1)
@@ -380,7 +388,8 @@ def forward(
x: Tensor,
inputs_positions: Optional[Tensor] = None,
decode: bool = False,
- cache: Optional[Dict[str, Dict[str, Tensor]]] = None
+ cache: Optional[Dict[str, Dict[str, Tensor]]] = None,
+ dropout_rate: Optional[float] = None
) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]:
"""
Args:
@@ -403,14 +412,14 @@ def forward(
}
pe = self.pe[0, cache[name]['cache_index'], :]
cache[name]['cache_index'] += 1
- return self.dropout(x + pe), cache
+ return F.dropout(x + pe, dropout_rate, self.training), cache
if inputs_positions is None:
# normal unpacked case:
pe = self.pe[:, :x.size(1), :]
else:
# for packed data we need to use known position indices:
pe = self.pe[0, inputs_positions, :]
- return self.dropout(x + pe)
+ return F.dropout(x + pe, dropout_rate, self.training)
# TransformerEncoderLayer and TransformerDecoderLayer are taken from:
@@ -448,7 +457,6 @@ def __init__(self,
d_model: int = 1024,
nhead: int = 16,
dim_feedforward: int = 1024,
- dropout_rate: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
glu: bool = False,
layer_norm_eps: float = 1e-6,
@@ -462,7 +470,6 @@ def __init__(self,
d_model,
nhead,
self_attn=True,
- dropout_rate=dropout_rate,
attention_temp=attention_temp,
bias=False,
**factory_kwargs)
@@ -472,50 +479,55 @@ def __init__(self,
self.glu = glu
if self.glu:
self.linear_glu = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
- self.dropout = nn.Dropout(dropout_rate)
self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
self.pre_ln = pre_ln
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
- self.dropout1 = nn.Dropout(dropout_rate)
- self.dropout2 = nn.Dropout(dropout_rate)
self.activation = activation
- def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
+ def forward(self,
+ src: Tensor,
+ src_mask: Optional[Tensor] = None,
+ dropout_rate: Optional[float] = None) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
-
+ dropout_rate: the dropout probability value (optional).
Shape:
see the docs in Transformer class.
"""
x = src
if self.pre_ln:
- x = x + self._sa_block(self.norm1(x), src_mask)
- x = x + self._ff_block(self.norm2(x))
+ x = x + self._sa_block(self.norm1(x), src_mask, dropout_rate)
+ x = x + self._ff_block(self.norm2(x), dropout_rate)
else:
- x = self.norm1(x + self._sa_block(x, src_mask))
- x = self.norm2(x + self._ff_block(x))
+ x = self.norm1(x + self._sa_block(x, src_mask, dropout_rate))
+ x = self.norm2(x + self._ff_block(x, dropout_rate))
return x
# Self-attention block:
- def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tensor:
- x, _ = self.self_attn(x, attn_mask=attn_mask)
- return self.dropout1(x)
+ def _sa_block(self,
+ x: Tensor,
+ attn_mask: Optional[Tensor],
+ dropout_rate: Optional[float] = None) -> Tensor:
+ x, _ = self.self_attn(x, attn_mask=attn_mask, dropout_rate=dropout_rate)
+ return F.dropout(x, dropout_rate, training=self.training)
# Feed forward block:
- def _ff_block(self, inputs: Tensor) -> Tensor:
+ def _ff_block(self,
+ inputs: Tensor,
+ dropout_rate: Optional[float] = None) -> Tensor:
x = self.activation(self.linear1(inputs))
if self.glu:
y = self.linear_glu(inputs)
x = x * y
- x = self.linear2(self.dropout(x))
- return self.dropout2(x)
+ x = self.linear2(F.dropout(x, dropout_rate, training=self.training))
+ return F.dropout(x, dropout_rate, training=self.training)
# Modified to use cache for autoregressive decoding and custom
@@ -527,7 +539,6 @@ class TransformerDecoder(nn.Module):
nhead: the number of heads in the multiheadattention models (default=16)
d_hid: the dimension of the feedforward network model
(default=1024)
- dropout_rate: the dropout_rate value (default=0.1)
layer_norm_eps: the eps value in layer normalization components
(default=1e-6).
decoder_layer: an instance of the TransformerDecoderLayer() class
@@ -545,7 +556,6 @@ def __init__(self,
d_model,
nhead,
d_hid,
- dropout_rate,
activation,
glu,
layer_norm_eps,
@@ -558,7 +568,6 @@ def __init__(self,
d_model,
nhead,
d_hid,
- dropout_rate,
activation,
glu,
layer_norm_eps=layer_norm_eps,
@@ -575,7 +584,8 @@ def forward(self,
memory_mask: Optional[Tensor] = None,
decode: bool = False,
max_len: Optional[int] = None,
- cache: Optional[dict] = None) -> Any:
+ cache: Optional[dict] = None,
+ dropout_rate: Optional[float] = None) -> Any:
r"""Pass the inputs (and mask) through the decoder layer in turn.
Args:
tgt: the sequence to the decoder (required).
@@ -584,6 +594,7 @@ def forward(self,
memory_mask: the mask for the memory sequence (optional).
decode: whether to use cache for autoregressive decoding or not.
max_len: maximum sequence length, necessary for decoding cache.
+ dropout_rate: the dropout probability value (optional)
Shape:
see the docs in Transformer class.
"""
@@ -598,7 +609,8 @@ def forward(self,
decode=decode,
max_len=max_len,
cache=cache,
- index=idx)
+ index=idx,
+ dropout_rate=dropout_rate)
if self.norm is not None:
output = self.norm(output)
@@ -624,7 +636,6 @@ class TransformerDecoderLayer(nn.Module):
nhead: the number of heads in the multiheadattention models (default=16).
dim_feedforward: the dimension of the feedforward network model
(default=1024).
- dropout_rate: the dropout_rate value (default=0.1).
activation: the activation function of the intermediate layer, can be a
string ("relu" or "gelu") or a unary callable (default=F.relu).
layer_norm_eps: the eps value in layer normalization components
@@ -644,7 +655,6 @@ def __init__(self,
d_model: int = 1024,
nhead: int = 16,
dim_feedforward: int = 1024,
- dropout_rate: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
glu: bool = False,
layer_norm_eps: float = 1e-6,
@@ -658,7 +668,6 @@ def __init__(self,
d_model,
nhead,
self_attn=True,
- dropout_rate=dropout_rate,
attention_temp=attention_temp,
bias=False,
**factory_kwargs)
@@ -666,7 +675,6 @@ def __init__(self,
d_model,
nhead,
self_attn=False,
- dropout_rate=dropout_rate,
attention_temp=attention_temp,
bias=False,
**factory_kwargs)
@@ -678,16 +686,12 @@ def __init__(self,
self.linear_glu = nn.Linear(dim_feedforward,
dim_feedforward,
**factory_kwargs)
- self.dropout = nn.Dropout(dropout_rate)
self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
self.pre_ln = pre_ln
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
- self.dropout1 = nn.Dropout(dropout_rate)
- self.dropout2 = nn.Dropout(dropout_rate)
- self.dropout3 = nn.Dropout(dropout_rate)
self.activation = activation
@@ -700,7 +704,8 @@ def forward( # pylint: disable=arguments-renamed
decode: bool = False,
max_len: Optional[int] = None,
cache: Optional[dict] = None,
- index: Optional[int] = None) -> Any:
+ index: Optional[int] = None,
+ dropout_rate: Optional[float] = None) -> Any:
r"""Pass the inputs (and mask) through the decoder layer.
Args:
tgt: the sequence to the decoder layer (required).
@@ -709,6 +714,7 @@ def forward( # pylint: disable=arguments-renamed
memory_mask: the mask for the memory sequence (optional).
       decode: whether to use cache for autoregressive decoding or not.
max_len: maximum sequence length, necessary for decoding cache.
+ dropout_rate: the dropout probability value (optional)
Shape:
see the docs in Transformer class.
"""
@@ -722,10 +728,11 @@ def forward( # pylint: disable=arguments-renamed
decode=decode,
max_len=max_len,
cache=cache,
- index=index)
+ index=index,
+ dropout_rate=dropout_rate)
x = x + sa_out
- x = x + self._mha_block(self.norm2(x), memory, memory_mask)
- x = x + self._ff_block(self.norm3(x))
+ x = x + self._mha_block(self.norm2(x), memory, memory_mask, dropout_rate)
+ x = x + self._ff_block(self.norm3(x), dropout_rate)
else:
sa_out, cache = self._sa_block(
x,
@@ -733,10 +740,11 @@ def forward( # pylint: disable=arguments-renamed
decode=decode,
max_len=max_len,
cache=cache,
- index=index)
+ index=index,
+ dropout_rate=dropout_rate)
x = self.norm1(x + sa_out)
- x = self.norm2(x + self._mha_block(x, memory, memory_mask))
- x = self.norm3(x + self._ff_block(x))
+ x = self.norm2(x + self._mha_block(x, memory, memory_mask, dropout_rate))
+ x = self.norm3(x + self._ff_block(x, dropout_rate))
return x, cache
@@ -748,30 +756,38 @@ def _sa_block( # pylint: disable=arguments-renamed
decode: bool = False,
max_len: Optional[int] = None,
cache: Optional[dict] = None,
- index: Optional[int] = None) -> Any:
+ index: Optional[int] = None,
+ dropout_rate: Optional[float] = None) -> Any:
x, cache = self.self_attn(
x,
attn_mask=attn_mask,
decode=decode,
max_len=max_len,
cache=cache,
- index=index)
- return self.dropout1(x), cache
+ index=index,
+ dropout_rate=dropout_rate)
+ return F.dropout(x, dropout_rate, self.training), cache
# Multihead attention block:
def _mha_block(self, x: Tensor, mem: Tensor,
- attn_mask: Optional[Tensor]) -> Tensor:
- x, _ = self.multihead_attn(x, mem, attn_mask=attn_mask)
- return self.dropout2(x)
+ attn_mask: Optional[Tensor],
+ dropout_rate: Optional[float] = None) -> Tensor:
+ x, _ = self.multihead_attn(
+ x,
+ mem,
+ attn_mask=attn_mask,
+ dropout_rate=dropout_rate)
+ return F.dropout(x, dropout_rate, self.training)
# Feed forward block.
- def _ff_block(self, inputs: Tensor) -> Tensor:
+ def _ff_block(self, inputs: Tensor,
+ dropout_rate: Optional[float] = None) -> Tensor:
x = self.activation(self.linear1(inputs))
if self.glu:
y = self.linear_glu(inputs)
x = x * y
- x = self.linear2(self.dropout(x))
- return self.dropout3(x)
+ x = self.linear2(F.dropout(x, dropout_rate, self.training))
+ return F.dropout(x, dropout_rate, self.training)
class MultiheadAttention(nn.Module):
@@ -789,8 +805,6 @@ class MultiheadAttention(nn.Module):
``embed_dim // num_heads``).
self_attn: Whether self attention or encoder-decoder attention is used.
Default: ``True``.
- dropout_rate: Dropout probability on ``attn_output_weights``.
- Default: ``0.0`` (no dropout_rate).
bias: If specified, adds bias to input / output projection layers.
Default: ``False``.
device: The device of the module.
@@ -804,7 +818,6 @@ def __init__(self,
embed_dim: int,
num_heads: int,
self_attn: bool = True,
- dropout_rate: float = 0.,
attention_temp: float = 1.0,
bias: bool = False,
device: Optional[torch.device] = None,
@@ -813,7 +826,6 @@ def __init__(self,
self.embed_dim = embed_dim
self.num_heads = num_heads
self.self_attn = self_attn
- self.dropout = dropout_rate
self.head_dim = embed_dim // num_heads
self.attention_temp = attention_temp
assert self.head_dim * num_heads == self.embed_dim, \
@@ -848,7 +860,8 @@ def forward(self,
decode: bool = False,
max_len: Optional[int] = None,
cache: Optional[dict] = None,
- index: Optional[int] = None) -> Any:
+ index: Optional[int] = None,
+ dropout_rate: Optional[float] = None) -> Any:
r"""
Args:
x: Batch of input sequences of shape
@@ -874,6 +887,7 @@ def forward(self,
max_len: maximum sequence length, necessary for decoding cache.
cache: cache dictionary for autoregressive decoding.
index: index of the current decoding step, necessary for decoding cache.
+ dropout_rate: dropout probability on ``attn_output_weights``.
Outputs:
- **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where
:math:`L` is the target sequence length, :math:`N` is the batch size,
@@ -963,12 +977,12 @@ def forward(self,
attn_mask = new_attn_mask
# Adjust dropout_rate probability.
- dropout_rate = self.dropout if self.training else 0.0
+ attn_dropout_rate = dropout_rate if self.training else 0.0
# Calculate attention.
q = self.attention_temp * q
attn_output = torch.nn.functional.scaled_dot_product_attention(
- q, k, v, attn_mask, dropout_rate)
+ q, k, v, attn_mask, attn_dropout_rate)
# Rearrange for output projection.
attn_output = attn_output.transpose(1, 2).contiguous().view(
bsz, tgt_len, embed_dim)
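
The refactor above applies one pattern throughout the file: module-level nn.Dropout
layers are removed, and the rate is threaded through forward and applied with
F.dropout. An isolated sketch of that pattern follows; the Block class and its names
are illustrative, not part of the workload code:

  import torch
  import torch.nn.functional as F
  from torch import nn

  class Block(nn.Module):

    def __init__(self, dim: int, dropout_rate: float = 0.1):
      super().__init__()
      self.linear = nn.Linear(dim, dim)
      self.dropout_rate = dropout_rate  # fallback when no rate is passed

    def forward(self, x: torch.Tensor, dropout_rate=None) -> torch.Tensor:
      if dropout_rate is None:
        dropout_rate = self.dropout_rate
      # Functional dropout lets callers override the rate per forward call
      # without rebuilding nn.Dropout modules.
      return F.dropout(self.linear(x), p=dropout_rate, training=self.training)

For example, Block(16)(torch.randn(4, 16), dropout_rate=0.0) disables dropout for a
single call while leaving the module's stored default untouched.
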
From af08bb91f93266d12e6ffeab7045b7c57cb9143f Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Wed, 11 Jun 2025 17:19:44 +0200
Subject: [PATCH 036/123] fix dropout, ALL tested
---
.../criteo1tb_pytorch/models_dropout.py | 35 +-
.../models_functional_dropout.py | 308 ------------------
.../fastmri/fastmri_pytorch/models_dropout.py | 13 +-
.../librispeech_pytorch/models_dropout.py | 2 +-
.../librispeech_pytorch/models_dropout.py | 2 +-
5 files changed, 37 insertions(+), 323 deletions(-)
delete mode 100644 algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
index 8042ec31e..d8d7393e4 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
@@ -79,6 +79,10 @@ def __init__(self,
self.mlp_bottom_dims = mlp_bottom_dims
self.mlp_top_dims = mlp_top_dims
self.embed_dim = embed_dim
+ if dropout_rate is None:
+ self.dropout_rate = 0.0
+ else:
+ self.dropout_rate = dropout_rate
# Ideally, we should use the pooled embedding implementation from
# `TorchRec`. However, in order to have identical implementation
@@ -127,17 +131,16 @@ def __init__(self,
block.append(nn.Linear(fan_in, fan_out))
if layer_idx < (num_layers_top - 1):
block.append(nn.ReLU(inplace=True))
- if (dropout_rate is not None and dropout_rate > 0.0 and
- layer_idx == num_layers_top - 2):
- block.append(CustomDropout()) # (nico)
- block = SequentialWithDropout(*block) # (nico)
+ if layer_idx == num_layers_top - 2:
+ block.append(CustomDropout())
+ block = SequentialWithDropout(*block)
if (layer_idx != 0) and (layer_idx != num_layers_top - 1):
block = DenseBlockWithDropout(block, resnet=True)
else:
block = DenseBlockWithDropout(block)
mlp_top_blocks.append(block)
fan_in = fan_out
- self.top_mlp = SequentialWithDropout(*mlp_top_blocks) # (nico)
+ self.top_mlp = SequentialWithDropout(*mlp_top_blocks)
for module in self.top_mlp.modules():
if isinstance(module, nn.Linear):
@@ -149,7 +152,10 @@ def __init__(self,
0.,
math.sqrt(1. / module.out_features))
- def forward(self, x, dropout_rate):
+ def forward(self, x, dropout_rate=None):
+ if dropout_rate is None:
+ dropout_rate = self.dropout_rate
+
batch_size = x.shape[0]
dense_features, sparse_features = torch.split(
@@ -201,6 +207,11 @@ def __init__(self,
self.mlp_top_dims = mlp_top_dims
self.embed_dim = embed_dim
self.embedding_init_multiplier = embedding_init_multiplier
+ self.dropout_rate = dropout_rate
+ if dropout_rate is None:
+ self.dropout_rate = 0.0
+ else:
+ self.dropout_rate = dropout_rate
# Ideally, we should use the pooled embedding implementation from
# `TorchRec`. However, in order to have identical implementation
@@ -253,10 +264,9 @@ def __init__(self,
top_mlp_layers.append(nn.ReLU(inplace=True))
if use_layer_norm:
top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6))
- if (dropout_rate is not None and dropout_rate > 0.0 and
- layer_idx == num_layers_top - 2):
- top_mlp_layers.append(CustomDropout()) # (nico)
- self.top_mlp = SequentialWithDropout(*top_mlp_layers) # (nico)
+ if layer_idx == num_layers_top - 2:
+ top_mlp_layers.append(CustomDropout())
+ self.top_mlp = SequentialWithDropout(*top_mlp_layers)
if use_layer_norm:
self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6)
else:
@@ -271,7 +281,10 @@ def __init__(self,
0.,
math.sqrt(1. / module.out_features))
- def forward(self, x, dropout_rate):
+ def forward(self, x, dropout_rate=None):
+ if dropout_rate is None:
+ dropout_rate = self.dropout_rate
+
batch_size = x.shape[0]
dense_features, sparse_features = torch.split(
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py
deleted file mode 100644
index 346e0e72a..000000000
--- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_functional_dropout.py
+++ /dev/null
@@ -1,308 +0,0 @@
-"""Pytorch implementation of DLRM-Small."""
-
-import math
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-
-
-class DenseBlock(nn.Module):
- """Dense block with optional residual connection.""" ""
-
- def __init__(self, module, resnet=False):
- super().__init__()
- self.module = module
- self.resnet = resnet
-
- def forward(self, x):
- if self.resnet:
- return self.module(x) + x
- else:
- return self.module(x)
-
-
-class DotInteract(nn.Module):
- """Performs feature interaction operation between dense or sparse features."""
-
- def __init__(self, num_sparse_features):
- super().__init__()
- self.triu_indices = torch.triu_indices(num_sparse_features + 1,
- num_sparse_features + 1)
-
- def forward(self, dense_features, sparse_features):
- combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features),
- dim=1)
- interactions = torch.bmm(combined_values,
- torch.transpose(combined_values, 1, 2))
- interactions_flat = interactions[:,
- self.triu_indices[0],
- self.triu_indices[1]]
- return torch.cat((dense_features, interactions_flat), dim=1)
-
-
-class DLRMResNet(nn.Module):
- """Define a DLRM-Small model.
-
- Parameters:
- vocab_size: vocab size of embedding table.
- num_dense_features: number of dense features as the bottom mlp input.
- mlp_bottom_dims: dimensions of dense layers of the bottom mlp.
- mlp_top_dims: dimensions of dense layers of the top mlp.
- embed_dim: embedding dimension.
- """
-
- def __init__(self,
- vocab_size,
- num_dense_features=13,
- num_sparse_features=26,
- mlp_bottom_dims=(256, 256, 256),
- mlp_top_dims=(256, 256, 256, 256, 1),
- embed_dim=128,
- # dropout_rate=0.0,
- use_layer_norm=False,
- embedding_init_multiplier=None):
- super().__init__()
- self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32)
- self.num_dense_features = num_dense_features
- self.num_sparse_features = num_sparse_features
- self.mlp_bottom_dims = mlp_bottom_dims
- self.mlp_top_dims = mlp_top_dims
- self.embed_dim = embed_dim
-
- # Ideally, we should use the pooled embedding implementation from
- # `TorchRec`. However, in order to have identical implementation
- # with that of Jax, we define a single embedding matrix.
- num_chunks = 4
- assert vocab_size % num_chunks == 0
- self.embedding_table_chucks = []
- scale = 1.0 / torch.sqrt(self.vocab_size)
- for i in range(num_chunks):
- chunk = nn.Parameter(
- torch.Tensor(self.vocab_size // num_chunks, self.embed_dim))
- chunk.data.uniform_(0, 1)
- chunk.data = scale * chunk.data
- self.register_parameter(f'embedding_chunk_{i}', chunk)
- self.embedding_table_chucks.append(chunk)
-
- input_dim = self.num_dense_features
- bot_mlp_blocks = []
- for layer_idx, dense_dim in enumerate(self.mlp_bottom_dims):
- block = []
- block.append(nn.Linear(input_dim, dense_dim))
- block.append(nn.ReLU(inplace=True))
- block = nn.Sequential(*block)
- if layer_idx > 0:
- block = DenseBlock(block, resnet=True)
- else:
- block = DenseBlock(block)
- bot_mlp_blocks.append(block)
- input_dim = dense_dim
- self.bot_mlp = nn.Sequential(*bot_mlp_blocks)
-
- for module in self.bot_mlp.modules():
- if isinstance(module, nn.Linear):
- limit = math.sqrt(6. / (module.in_features + module.out_features))
- nn.init.uniform_(module.weight.data, -limit, limit)
- nn.init.normal_(module.bias.data,
- 0.,
- math.sqrt(1. / module.out_features))
-
- # Number of sparse features = 26
- fan_in = (26 * self.embed_dim) + self.mlp_bottom_dims[-1]
- num_layers_top = len(self.mlp_top_dims)
- mlp_top_blocks = []
- for layer_idx, fan_out in enumerate(self.mlp_top_dims):
- block = []
- block.append(nn.Linear(fan_in, fan_out))
- if layer_idx < (num_layers_top - 1):
- block.append(nn.ReLU(inplace=True))
- # if (dropout_rate is not None and dropout_rate > 0.0 and
- # layer_idx == num_layers_top - 2):
- # block.append(nn.Dropout(p=dropout_rate))
- block = nn.Sequential(*block)
- if (layer_idx != 0) and (layer_idx != num_layers_top - 1):
- block = DenseBlock(block, resnet=True)
- else:
- block = DenseBlock(block)
- mlp_top_blocks.append(block)
- fan_in = fan_out
- self.top_mlp = nn.Sequential(*mlp_top_blocks)
-
- for module in self.top_mlp.modules():
- if isinstance(module, nn.Linear):
- nn.init.normal_(
- module.weight.data,
- 0.,
- math.sqrt(2. / (module.in_features + module.out_features)))
- nn.init.normal_(module.bias.data,
- 0.,
- math.sqrt(1. / module.out_features))
-
- def forward(self, x, dropout_rate):
- batch_size = x.shape[0]
-
- dense_features, sparse_features = torch.split(
- x, [self.num_dense_features, self.num_sparse_features], 1)
-
- # Bottom MLP.
- embedded_dense = self.bot_mlp(dense_features)
-
- # Sparse feature processing.
- sparse_features = sparse_features.to(dtype=torch.int32)
- idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size
- embedding_table = torch.cat(self.embedding_table_chucks, dim=0)
- embedded_sparse = embedding_table[idx_lookup]
- embedded_sparse = torch.reshape(embedded_sparse,
- [batch_size, 26 * self.embed_dim])
- top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1)
-
- # Final MLP (horrible!!).
- h = top_mlp_input
- num_layers_top = len(self.mlp_top_dims)
- for layer_idx, block in enumerate(self.top_mlp):
- # block.module is nn.Sequential([...])
- seq = block.module
- # 1) linear
- out = seq[0](h)
- # 2) ReLU (if present)
- if layer_idx < (num_layers_top - 1):
- out = seq[1](out)
- # 3) functional dropout at penult layer
- if dropout_rate > 0 and layer_idx == num_layers_top - 2:
- out = F.dropout(out, dropout_rate, training=self.training)
- # 4) wrap in residual if needed
- h = out + h if block.resnet else out
- return h
-
-
-class DlrmSmall(nn.Module):
- """Define a DLRM-Small model.
-
- Parameters:
- vocab_size: vocab size of embedding table.
- num_dense_features: number of dense features as the bottom mlp input.
- mlp_bottom_dims: dimensions of dense layers of the bottom mlp.
- mlp_top_dims: dimensions of dense layers of the top mlp.
- embed_dim: embedding dimension.
- """
-
- def __init__(self,
- vocab_size,
- num_dense_features=13,
- num_sparse_features=26,
- mlp_bottom_dims=(512, 256, 128),
- mlp_top_dims=(1024, 1024, 512, 256, 1),
- embed_dim=128,
- # dropout_rate=0.0,
- use_layer_norm=False,
- embedding_init_multiplier=None):
- super().__init__()
- self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32)
- self.num_dense_features = num_dense_features
- self.num_sparse_features = num_sparse_features
- self.mlp_bottom_dims = mlp_bottom_dims
- self.mlp_top_dims = mlp_top_dims
- self.embed_dim = embed_dim
- self.embedding_init_multiplier = embedding_init_multiplier
-
- # Ideally, we should use the pooled embedding implementation from
- # `TorchRec`. However, in order to have identical implementation
- # with that of Jax, we define a single embedding matrix.
- num_chunks = 4
- assert vocab_size % num_chunks == 0
- self.embedding_table_chucks = []
-
- if self.embedding_init_multiplier is None:
- scale = 1.0 / torch.sqrt(self.vocab_size)
- else:
- scale = self.embedding_init_multiplier
-
- for i in range(num_chunks):
- chunk = nn.Parameter(
- torch.Tensor(self.vocab_size // num_chunks, self.embed_dim))
- chunk.data.uniform_(0, 1)
- chunk.data = scale * chunk.data
- self.register_parameter(f'embedding_chunk_{i}', chunk)
- self.embedding_table_chucks.append(chunk)
-
- input_dim = self.num_dense_features
- bottom_mlp_layers = []
- for dense_dim in self.mlp_bottom_dims:
- bottom_mlp_layers.append(nn.Linear(input_dim, dense_dim))
- bottom_mlp_layers.append(nn.ReLU(inplace=True))
- if use_layer_norm:
- bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6))
- input_dim = dense_dim
- self.bot_mlp = nn.Sequential(*bottom_mlp_layers)
- for module in self.bot_mlp.modules():
- if isinstance(module, nn.Linear):
- limit = math.sqrt(6. / (module.in_features + module.out_features))
- nn.init.uniform_(module.weight.data, -limit, limit)
- nn.init.normal_(module.bias.data,
- 0.,
- math.sqrt(1. / module.out_features))
-
- self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,)
-
- # TODO: Write down the formula here instead of the constant.
- input_dims = 506
- num_layers_top = len(self.mlp_top_dims)
- top_mlp_layers = []
- for layer_idx, fan_out in enumerate(self.mlp_top_dims):
- fan_in = input_dims if layer_idx == 0 \
- else self.mlp_top_dims[layer_idx - 1]
- top_mlp_layers.append(nn.Linear(fan_in, fan_out))
- if layer_idx < (num_layers_top - 1):
- top_mlp_layers.append(nn.ReLU(inplace=True))
- if use_layer_norm:
- top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6))
- # if (dropout_rate is not None and dropout_rate > 0.0 and
- # layer_idx == num_layers_top - 2):
- # top_mlp_layers.append(nn.Dropout(p=dropout_rate))
- self.top_mlp = nn.Sequential(*top_mlp_layers)
- if use_layer_norm:
- self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6)
- else:
- self.embed_ln = None
- for module in self.top_mlp.modules():
- if isinstance(module, nn.Linear):
- nn.init.normal_(
- module.weight.data,
- 0.,
- math.sqrt(2. / (module.in_features + module.out_features)))
- nn.init.normal_(module.bias.data,
- 0.,
- math.sqrt(1. / module.out_features))
-
- def forward(self, x, dropout_rate):
- batch_size = x.shape[0]
-
- dense_features, sparse_features = torch.split(
- x, [self.num_dense_features, self.num_sparse_features], 1)
-
- # Bottom MLP.
- embedded_dense = self.bot_mlp(dense_features)
-
- # Sparse feature processing.
- sparse_features = sparse_features.to(dtype=torch.int32)
- idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size
- embedding_table = torch.cat(self.embedding_table_chucks, dim=0)
- embedded_sparse = embedding_table[idx_lookup]
- embedded_sparse = torch.reshape(embedded_sparse,
- [batch_size, -1, self.embed_dim])
- if self.embed_ln:
- embedded_sparse = self.embed_ln(embedded_sparse)
- # Dot product interactions.
- concatenated_dense = self.dot_interact(
- dense_features=embedded_dense, sparse_features=embedded_sparse)
-
- # Final MLP: run each layer, and after the penultimate layer do functional dropout
- h = concatenated_dense
- N = len(self.top_mlp)
- for idx, layer in enumerate(self.top_mlp):
- h = layer(h)
- # insert dropout exactly where nn.Dropout used to live
- if dropout_rate > 0 and idx == N - 2:
- h = F.dropout(h, dropout_rate, training=self.training)
- return h
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
index 5862f6352..8954cb737 100644
--- a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
@@ -29,6 +29,7 @@ def __init__(self,
out_chans: int = 1,
num_channels: int = 32,
num_pool_layers: int = 4,
+ dropout_rate: Optional[float] = 0.0,
use_tanh: bool = False,
use_layer_norm: bool = False) -> None:
super().__init__()
@@ -37,6 +38,11 @@ def __init__(self,
self.out_chans = out_chans
self.num_channels = num_channels
self.num_pool_layers = num_pool_layers
+ if dropout_rate is None:
+ self.dropout_rate = 0.0
+ else:
+ self.dropout_rate = dropout_rate
+
self.down_sample_layers = nn.ModuleList([
ConvBlock(in_chans,
num_channels,
@@ -72,7 +78,10 @@ def __init__(self,
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
init_utils.pytorch_default_init(m)
- def forward(self, x: Tensor, dropout_rate: float) -> Tensor:
+ def forward(self, x: Tensor, dropout_rate: Optional[float] = None) -> Tensor:
+ if dropout_rate is None:
+ dropout_rate = self.dropout_rate
+
stack = []
output = x
@@ -136,7 +145,7 @@ def __init__(self,
CustomDropout2d(),
)
- def forward(self, x: Tensor, dropout_rate: float) -> Tensor:
+ def forward(self, x: Tensor, dropout_rate: Optional[float] = None) -> Tensor:
return self.conv_layers(x, dropout_rate)
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py
index da66dfe43..9ff662fb8 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py
@@ -510,7 +510,7 @@ def forward(self, inputs, input_paddings, dropout_rate=None):
outputs, output_paddings = self.preprocessor(outputs, output_paddings)
if self.training and self.config.use_specaug:
outputs, output_paddings = self.specaug(outputs, output_paddings)
- outputs, output_paddings = self.subsample(outputs, output_paddings)
+ outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate)
for conformer in self.conformers:
outputs = conformer(outputs, output_paddings, dropout_rate)
outputs = self.ln(outputs)
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py
index e68a820ed..8797aa578 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py
@@ -383,7 +383,7 @@ def forward(self, inputs, input_paddings, dropout_rate=None):
for idx in range(self.config.num_ffn_layers):
if self.config.enable_residual_connections:
- outputs = outputs + self.ffns[idx](outputs, output_paddings)
+ outputs = outputs + self.ffns[idx](outputs, output_paddings, dropout_rate)
else:
outputs = self.ffns[idx](outputs, output_paddings, dropout_rate)
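The librispeech hunks above thread the per-call `dropout_rate` through submodules (`subsample`, `ffns`) rather than relying on rates fixed at init. A toy sketch of that threading pattern with made-up classes (not the Conformer/Deepspeech modules):

import torch
import torch.nn.functional as F


class ToyBlock(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(8, 8)

  def forward(self, x, dropout_rate=0.0):
    return F.dropout(self.linear(x), dropout_rate, training=self.training)


class ToyEncoder(torch.nn.Module):

  def __init__(self, num_blocks=3):
    super().__init__()
    self.blocks = torch.nn.ModuleList(ToyBlock() for _ in range(num_blocks))

  def forward(self, x, dropout_rate=0.0):
    for block in self.blocks:
      # Residual connection plus an explicit per-call rate, as in the hunks above.
      x = x + block(x, dropout_rate)
    return x


out = ToyEncoder()(torch.randn(4, 8), dropout_rate=0.1)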
From 7a6651a69953af4655e0f7c50b8c7fefe71aa9ca Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Wed, 11 Jun 2025 17:20:39 +0200
Subject: [PATCH 037/123] add dropout equivalence tests
---
.../test_model_equivalence.py | 77 ++++++++++++
.../fastmri_pytorch/test_model_equivalence.py | 98 +++++++++++++++
.../test_model_equivalence.py | 112 ++++++++++++++++++
.../test_model_equivalence.py | 91 ++++++++++++++
.../test_model_equivalence.py | 89 ++++++++++++++
.../ogbg_pytorch/test_model_equivalence.py | 76 ++++++++++++
.../wmt_pytorch/test_model_equivalence.py | 83 +++++++++++++
7 files changed, 626 insertions(+)
create mode 100644 tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
create mode 100644 tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
create mode 100644 tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
create mode 100644 tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
create mode 100644 tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
create mode 100644 tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
create mode 100644 tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
new file mode 100644
index 000000000..b9b1232ef
--- /dev/null
+++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
@@ -0,0 +1,77 @@
+"""
+Runs fwd pass with random input for our DLRM models and compares outputs.
+Run it as:
+ python3 tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
+"""
+
+from absl.testing import absltest, parameterized
+from torch.testing import assert_close
+import torch
+import os
+
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import (
+ DLRMResNet as OriginalDLRMResNet,
+ DlrmSmall as OriginalDlrmSmall,
+)
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models_dropout import (
+ DLRMResNet as CustomDLRMResNet,
+ DlrmSmall as CustomDlrmSmall,
+)
+
+
+BATCH, DENSE, SPARSE = 16, 13, 26
+FEATURES = DENSE + SPARSE
+VOCAB = 1000
+DEVICE = 'cuda'
+TORCH_COMPILE = False
+SEED = 1996
+
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
+torch.use_deterministic_algorithms(True)
+
+class ModelEquivalenceTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ dict(testcase_name='DLRMResNet, p=None', model='dlrm_resnet', dropout_rate=None),
+ dict(testcase_name='DlrmSmall, p=None', model='dlrm_small', dropout_rate=None),
+ dict(testcase_name='DLRMResNet, p=0.0', model='dlrm_resnet', dropout_rate=0.0),
+ dict(testcase_name='DlrmSmall, p=0.0', model='dlrm_small', dropout_rate=0.0),
+ dict(testcase_name='DLRMResNet, p=0.1', model='dlrm_resnet', dropout_rate=0.1),
+ dict(testcase_name='DlrmSmall, p=0.1', model='dlrm_small', dropout_rate=0.1),
+ dict(testcase_name='DLRMResNet, p=1.0', model='dlrm_resnet', dropout_rate=1.0),
+ dict(testcase_name='DlrmSmall, p=1.0', model='dlrm_small', dropout_rate=1.0),
+ )
+ def test_forward(self, model, dropout_rate):
+ OrigCls, CustCls = (
+ (OriginalDLRMResNet, CustomDLRMResNet)
+ if model == 'dlrm_resnet'
+ else (OriginalDlrmSmall, CustomDlrmSmall)
+ )
+
+    # Test initializing custom model with a None dropout_rate
+ for custom_init_dropout_rate in [dropout_rate, None]:
+
+ torch.manual_seed(SEED)
+ orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate)
+ orig.to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustCls(vocab_size=VOCAB, dropout_rate=custom_init_dropout_rate)
+ cust.to(DEVICE)
+
+ if TORCH_COMPILE:
+ orig = torch.compile(orig); cust = torch.compile(cust)
+
+ x = torch.randn(BATCH, FEATURES, device=DEVICE)
+
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)(); getattr(cust, mode)()
+ torch.manual_seed(SEED); y1 = orig(x)
+ torch.manual_seed(SEED); y2 = cust(x, dropout_rate)
+ assert_close(y1, y2, atol=0, rtol=0)
+
+
+if __name__ == '__main__':
+ absltest.main()
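These equivalence tests rely on reseeding the global RNG immediately before each forward pass so that both implementations draw identical dropout masks. A minimal standalone illustration of that trick with plain functional dropout (no workload code involved):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)

torch.manual_seed(0)
y1 = F.dropout(x, 0.5, training=True)  # "original" code path

torch.manual_seed(0)
y2 = F.dropout(x, 0.5, training=True)  # "custom" code path, identical mask

torch.testing.assert_close(y1, y2, atol=0, rtol=0)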
diff --git a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
new file mode 100644
index 000000000..6339ff21b
--- /dev/null
+++ b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
@@ -0,0 +1,98 @@
+"""
+Runs fwd pass with random input for FASTMRI U-Net models and compares outputs.
+Run it as:
+ python3 tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
+"""
+
+from absl.testing import absltest, parameterized
+from torch.testing import assert_close
+import torch
+import os
+
+from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet as OriginalUNet
+from algoperf.workloads.fastmri.fastmri_pytorch.models_dropout import UNet as CustomUNet
+
+BATCH, IN_CHANS, H, W = 4, 1, 256, 256
+OUT_CHANS, C, LAYERS = 1, 32, 4
+DEVICE = 'cuda'
+TORCH_COMPILE = True
+SEED = 1996
+
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
+torch.use_deterministic_algorithms(True)
+
+class FastMRIModeEquivalenceTest(parameterized.TestCase):
+
+ def fwd_pass(self, orig, cust, dropout_rate):
+ x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+ torch.manual_seed(0)
+ y1 = orig(x)
+ torch.manual_seed(0)
+ y2 = cust(x, dropout_rate)
+ assert_close(y1, y2, atol=0, rtol=0)
+
+ @parameterized.named_parameters(
+ dict(testcase_name='p=None', dropout_rate=None),
+ dict(testcase_name='p=0.0', dropout_rate=0.0),
+ dict(testcase_name='p=0.1', dropout_rate=0.1),
+ dict(testcase_name='p=1.0', dropout_rate=1.0),
+ )
+ def test_dropout_values(self, dropout_rate):
+ """Test different values of dropout_rate."""
+
+    # Test initializing custom model with a None dropout_rate
+ for custom_init_dropout_rate in [dropout_rate, None]:
+
+ torch.manual_seed(SEED)
+ orig = OriginalUNet(
+ IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate
+ ).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustomUNet(
+ IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=custom_init_dropout_rate
+ ).to(DEVICE)
+
+ cust.load_state_dict(orig.state_dict()) # sync weights
+ if TORCH_COMPILE:
+ orig = torch.compile(orig); cust = torch.compile(cust)
+
+ self.fwd_pass(orig, cust, dropout_rate)
+
+
+ @parameterized.named_parameters(
+ dict(testcase_name='default', use_tanh=False, use_layer_norm=False),
+ dict(testcase_name='tanh', use_tanh=True, use_layer_norm=False),
+ dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True),
+ dict(testcase_name='both', use_tanh=True, use_layer_norm=True),
+ )
+ def test_arch_setups(self, use_tanh, use_layer_norm):
+ """Test different architecture configurations, fixed dropout_rate."""
+ dropout_rate = 0.1
+
+ torch.manual_seed(SEED)
+ orig = OriginalUNet(
+ IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate,
+ use_tanh=use_tanh, use_layer_norm=use_layer_norm
+ ).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustomUNet(
+ IN_CHANS, OUT_CHANS, C, LAYERS,
+ use_tanh=use_tanh, use_layer_norm=use_layer_norm
+ ).to(DEVICE)
+
+ cust.load_state_dict(orig.state_dict()) # sync weights
+ if TORCH_COMPILE:
+ orig = torch.compile(orig); cust = torch.compile(cust)
+
+ self.fwd_pass(orig, cust, dropout_rate)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
new file mode 100644
index 000000000..56644f152
--- /dev/null
+++ b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
@@ -0,0 +1,112 @@
+"""
+Runs fwd pass with random input for ImageNet ViT models and compares outputs.
+Run it as:
+ python3 tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
+"""
+
+from absl.testing import absltest, parameterized
+from torch.testing import assert_close
+import torch
+import os
+import itertools
+
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import ViT as OriginalVit
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.models_dropout import ViT as CustomVit
+
+# Model / test hyper-params
+BATCH, C, H, W = 4, 3, 224, 224 # input shape (N,C,H,W)
+WIDTH, DEPTH, HEADS = 256, 4, 8
+DROPOUT_RATE = None
+DEVICE = 'cuda'
+SEED = 1996
+
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
+torch.use_deterministic_algorithms(True)
+
+class ImageNetVitModeEquivalenceTest(parameterized.TestCase):
+
+ def fwd_pass(self, orig, cust, dropout_rate):
+ x = torch.randn(BATCH, C, H, W, device=DEVICE)
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+ torch.manual_seed(0); y1 = orig(x)
+ torch.manual_seed(0); y2 = cust(x, dropout_rate)
+ assert_close(y1, y2, atol=0, rtol=0)
+
+ @parameterized.named_parameters(
+ dict(testcase_name='p=None', dropout_rate=None),
+ dict(testcase_name='p=0.0', dropout_rate=0.0),
+ dict(testcase_name='p=0.1', dropout_rate=0.1),
+ dict(testcase_name='p=0.6', dropout_rate=0.6),
+ dict(testcase_name='p=1.0', dropout_rate=1.0),
+ )
+ def test_dropout_values(self, dropout_rate):
+ """Test different dropout_values."""
+
+    # Test initializing custom model with a None dropout_rate
+ for custom_init_dropout_rate in [dropout_rate, None]:
+
+ torch.manual_seed(SEED)
+ orig = OriginalVit(
+ width=WIDTH,
+ depth=DEPTH,
+ num_heads=HEADS,
+ dropout_rate=dropout_rate,
+ ).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustomVit(
+ width=WIDTH,
+ depth=DEPTH,
+ num_heads=HEADS,
+ dropout_rate=custom_init_dropout_rate,
+ ).to(DEVICE)
+
+ cust.load_state_dict(orig.state_dict()) # sync weights
+ self.fwd_pass(orig, cust, dropout_rate)
+
+
+ @parameterized.named_parameters([
+ dict(
+ testcase_name=f"GLU={use_glu}_LN={use_post_ln}_MAP={use_map}",
+ use_glu=use_glu,
+ use_post_ln=use_post_ln,
+ use_map=use_map,
+ )
+ for use_glu, use_post_ln, use_map in itertools.product([False, True], repeat=3)
+ ])
+ def test_arch(self, use_glu, use_post_ln, use_map):
+ """Test different architecture configurations, fixed dropout_rate."""
+ dropout_rate = 0.1
+
+ torch.manual_seed(SEED)
+ orig = OriginalVit(
+ width=WIDTH,
+ depth=DEPTH,
+ num_heads=HEADS,
+ use_glu=use_glu,
+ use_post_layer_norm=use_post_ln,
+ use_map=use_map,
+ dropout_rate=dropout_rate,
+ ).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustomVit(
+ width=WIDTH,
+ depth=DEPTH,
+ num_heads=HEADS,
+ use_glu=use_glu,
+ use_post_layer_norm=use_post_ln,
+ use_map=use_map,
+ dropout_rate=None,
+ ).to(DEVICE)
+
+ cust.load_state_dict(orig.state_dict()) # sync weights
+ self.fwd_pass(orig, cust, dropout_rate)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
new file mode 100644
index 000000000..19525a98b
--- /dev/null
+++ b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
@@ -0,0 +1,91 @@
+"""
+Runs fwd pass with random input for LIBRISPEECH Conformer models and compares outputs.
+Run with:
+ python3 tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
+
+`dropout_rate` controls the following args:
+- `attention_residual_dropout_rate` (if None, 0.1)
+- `conv_residual_dropout_rate` (if None, 0.0)
+- `feed_forward_residual_dropout_rate` (if None, 0.1)
+- `input_dropout_rate` (if None, 0.1)
+"""
+
+from absl.testing import absltest, parameterized
+from torch.testing import assert_close
+import torch
+import os
+
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import (
+ # ConformerConfig,
+ ConformerEncoderDecoder as OriginalConf
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import(
+ ConformerEncoderDecoder as CustomConf,
+ ConformerConfig,
+)
+
+B, T = 32, 36_000
+DEVICE = 'cuda'
+
+os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
+torch.use_deterministic_algorithms(mode=True)
+SEED = 1996
+
+
+class ConformerEquivalenceTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ dict(testcase_name='p=None', dropout_rate=None),
+ dict(testcase_name='p=0.0', dropout_rate=0.0),
+ dict(testcase_name='p=0.2', dropout_rate=0.2),
+ dict(testcase_name='p=0.7', dropout_rate=0.7),
+ dict(testcase_name='p=1.0', dropout_rate=1.0),
+ )
+ def test_forward(self, dropout_rate):
+
+    # Test initializing custom model with a None dropout_rate
+ for custom_init_dropout_rate in [None, dropout_rate]:
+
+ torch.manual_seed(SEED)
+ orig = OriginalConf(
+ ConformerConfig(
+ num_encoder_layers=3,
+ attention_residual_dropout_rate=dropout_rate,
+ conv_residual_dropout_rate=dropout_rate,
+ feed_forward_residual_dropout_rate=dropout_rate,
+ input_dropout_rate=dropout_rate,
+ )).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustomConf(
+ ConformerConfig(
+ num_encoder_layers=3,
+ attention_residual_dropout_rate=custom_init_dropout_rate,
+ conv_residual_dropout_rate=custom_init_dropout_rate,
+ feed_forward_residual_dropout_rate=custom_init_dropout_rate,
+ input_dropout_rate=custom_init_dropout_rate
+ )).to(DEVICE)
+
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ x = torch.randn(B, T, device=DEVICE)
+ paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
+
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
+ torch.manual_seed(SEED)
+ y1, p1 = orig(x, paddings)
+
+ torch.manual_seed(SEED)
+ y2, p2 = cust(x, paddings, dropout_rate=dropout_rate)
+
+ assert_close(y1, y2, atol=0, rtol=0)
+ assert_close(p1, p2, atol=0, rtol=0)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
new file mode 100644
index 000000000..e31f4a7eb
--- /dev/null
+++ b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
@@ -0,0 +1,89 @@
+"""
+Runs fwd pass with random input for LIBRISPEECH Deepspeech models and compares outputs.
+Run with:
+ python3 tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
+
+`dropout_rate` controls the following args:
+- `input_dropout_rate` (if None, 0.1)
+- `feed_forward_dropout_rate` (if None, 0.1)
+"""
+
+from absl.testing import absltest, parameterized
+from torch.testing import assert_close
+import torch
+import os
+
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import (
+ DeepspeechEncoderDecoder as OriginalModel
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import(
+ DeepspeechEncoderDecoder as CustomModel,
+ DeepspeechConfig,
+)
+
+B, T = 32, 30_000
+DEVICE = 'cuda'
+TORCH_COMPILE = True
+
+os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
+torch.use_deterministic_algorithms(mode=True)
+SEED = 1996
+
+
+class DeepSpeechEquivalenceTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ dict(testcase_name='p=None', dropout_rate=None),
+ dict(testcase_name='p=0.0', dropout_rate=0.0),
+ dict(testcase_name='p=0.2', dropout_rate=0.2),
+ dict(testcase_name='p=0.7', dropout_rate=0.7),
+ dict(testcase_name='p=1.0', dropout_rate=1.0),
+ )
+ def test_forward(self, dropout_rate):
+
+    # Test initializing custom model with a None dropout_rate
+ for custom_init_dropout_rate in [None, dropout_rate]:
+
+ torch.manual_seed(SEED)
+ orig = OriginalModel(
+ DeepspeechConfig(
+ num_lstm_layers=2,
+ num_ffn_layers=2,
+ input_dropout_rate=dropout_rate,
+ feed_forward_dropout_rate=dropout_rate,
+ )).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustomModel(DeepspeechConfig(
+ num_lstm_layers=2,
+ num_ffn_layers=2,
+ input_dropout_rate=custom_init_dropout_rate,
+ feed_forward_dropout_rate=custom_init_dropout_rate,
+ )).to(DEVICE)
+
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ if TORCH_COMPILE:
+ orig = torch.compile(orig); cust = torch.compile(cust)
+
+ x = torch.randn(B, T, device=DEVICE)
+ paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
+
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
+ torch.manual_seed(SEED)
+ y1, p1 = orig(x, paddings)
+
+ torch.manual_seed(SEED)
+ y2, p2 = cust(x, paddings, dropout_rate=dropout_rate)
+
+ assert_close(y1, y2, atol=0, rtol=0)
+ assert_close(p1, p2, atol=0, rtol=0)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
new file mode 100644
index 000000000..cc1857705
--- /dev/null
+++ b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
@@ -0,0 +1,76 @@
+"""
+Runs fwd pass with random graphs for OGBG GNN models and compares outputs.
+Run with:
+ python3 tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
+"""
+
+from absl.testing import absltest, parameterized
+from torch.testing import assert_close
+import torch, os, random, numpy as np
+from jraph import GraphsTuple
+
+from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as OriginalModel
+from algoperf.workloads.ogbg.ogbg_pytorch.models_dropout import GNN as CustomModel
+
+B, N, E = 8, 20, 40 # graphs, nodes/graph, edges/graph
+NODE_FDIM, EDGE_FDIM = 9, 3 # expected feature dims
+DEVICE = 'cuda'
+
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
+torch.use_deterministic_algorithms(True)
+SEED = 1996
+
+
+def _rand_graph():
+ total_nodes, total_edges = B * N, B * E
+ nodes = torch.randn(total_nodes, NODE_FDIM, device=DEVICE)
+ edges = torch.randn(total_edges, EDGE_FDIM, device=DEVICE)
+ senders, receivers = [], []
+ for i in range(B):
+ offset = i * N
+ s = torch.randint(N, (E,), device=DEVICE) + offset
+ r = torch.randint(N, (E,), device=DEVICE) + offset
+ senders.append(s), receivers.append(r)
+ senders = torch.cat(senders); receivers = torch.cat(receivers)
+ n_node = torch.full((B,), N, device=DEVICE, dtype=torch.int32)
+ n_edge = torch.full((B,), E, device=DEVICE, dtype=torch.int32)
+ return GraphsTuple(nodes, edges, receivers, senders, None, n_node, n_edge)
+
+
+class GNNEquivalenceTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ dict(testcase_name='None', dropout_rate=None),
+ dict(testcase_name='0.0', dropout_rate=0.0),
+ dict(testcase_name='0.2', dropout_rate=0.2),
+ dict(testcase_name='0.7', dropout_rate=0.7),
+ dict(testcase_name='1.0', dropout_rate=1.0),
+ )
+ def test_forward(self, dropout_rate):
+
+    # Test initializing custom model with a None dropout_rate
+ for custom_init_dropout_rate in [None, dropout_rate]:
+
+ orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE)
+ cust = CustomModel(dropout_rate=custom_init_dropout_rate).to(DEVICE)
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ graph = _rand_graph()
+
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
+ torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
+ y1 = orig(graph)
+
+ torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
+ y2 = cust(graph, dropout_rate=dropout_rate)
+
+ assert_close(y1, y2, atol=0, rtol=0)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
new file mode 100644
index 000000000..9aca717d9
--- /dev/null
+++ b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
@@ -0,0 +1,83 @@
+"""
+Runs fwd pass with random input for WMT Transformer models and compares outputs.
+Run with:
+ python3 tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
+"""
+
+from absl.testing import absltest, parameterized
+from torch.testing import assert_close
+import torch, os, random, numpy as np
+
+from algoperf.workloads.wmt.wmt_pytorch.models import (
+ Transformer as OriginalModel,
+)
+from algoperf.workloads.wmt.wmt_pytorch.models_dropout import (
+ Transformer as CustomModel,
+)
+
+B, SRC_LEN, TGT_LEN, NTOK = 16, 80, 80, 32_000
+DEVICE = "cuda"
+SEED = 1996
+
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
+torch.use_deterministic_algorithms(True)
+
+
+def _rand_tokens(bs, seqlen):
+ return torch.randint(1, NTOK, (bs, seqlen), device=DEVICE)
+
+
+class TransformerEquivalenceTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+      # NOTE: dropout_rate=1.0 is omitted; it generates NaNs in scaled_dot_product_attention
+
+ dict(testcase_name="None", dropout_rate=None, compile=False),
+ dict(testcase_name="0.0", dropout_rate=0.0, compile=False),
+ dict(testcase_name="0.2", dropout_rate=0.2, compile=False),
+ dict(testcase_name="0.7", dropout_rate=0.7, compile=False),
+
+ dict(testcase_name="p=None, compile", dropout_rate=None, compile=True),
+ dict(testcase_name="p=0.0, compile", dropout_rate=0.0, compile=True),
+ dict(testcase_name="p=0.2, compile", dropout_rate=0.2, compile=True),
+ dict(testcase_name="p=0.7, compile", dropout_rate=0.7, compile=True),
+ )
+ def test_forward(self, dropout_rate, compile):
+
+    # Test initializing custom model with a None dropout_rate
+ for custom_init_dropout_rate in [None, dropout_rate]:
+
+ orig = OriginalModel(
+ dropout_rate=dropout_rate,
+ attention_dropout_rate=dropout_rate
+ ).to(DEVICE)
+ cust = CustomModel(
+ dropout_rate=custom_init_dropout_rate
+ ).to(DEVICE)
+
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ if compile:
+ orig = torch.compile(orig)
+ cust = torch.compile(cust)
+
+ src = _rand_tokens(B, SRC_LEN)
+ tgt = _rand_tokens(B, TGT_LEN)
+
+ for mode in ("train", "eval"):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
+ torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
+ y1 = orig(src, tgt)
+
+ torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
+ y2 = cust(src, tgt, dropout_rate=dropout_rate)
+
+ assert_close(y1, y2, atol=0, rtol=0)
+
+
+if __name__ == "__main__":
+ absltest.main()
From a7ff3d1ab09c57807f4d0c7b219803407c085c69 Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Wed, 11 Jun 2025 17:39:46 +0200
Subject: [PATCH 038/123] moved custom dropout to pytorch_utils
---
algoperf/pytorch_utils.py | 38 ++++++++++++++++++
.../criteo1tb_pytorch/models_dropout.py | 2 +-
algoperf/workloads/dropout_modules.py | 40 -------------------
.../fastmri/fastmri_pytorch/models_dropout.py | 2 +-
.../ogbg/ogbg_pytorch/models_dropout.py | 2 +-
5 files changed, 41 insertions(+), 43 deletions(-)
delete mode 100644 algoperf/workloads/dropout_modules.py
diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py
index 4a674985d..4af77088e 100644
--- a/algoperf/pytorch_utils.py
+++ b/algoperf/pytorch_utils.py
@@ -5,7 +5,10 @@
import jax
import tensorflow as tf
import torch
+from torch import Tensor
+import torch.nn as nn
import torch.distributed as dist
+import torch.nn.functional as F
from algoperf import spec
from algoperf.profiler import Profiler
@@ -77,3 +80,38 @@ def update_batch_norm_fn(module: spec.ParameterContainer,
module.momentum = 0.0
elif hasattr(module, 'momentum_backup'):
module.momentum = module.momentum_backup
+
+
+class CustomDropout(nn.Module):
+ """A module around torch.nn.functional.dropout."""
+ def __init__(self):
+ super().__init__()
+ self._supports_custom_dropout = True
+
+ def forward(self, input: Tensor, p: float) -> Tensor:
+ return F.dropout(input, p, training=self.training)
+
+
+class CustomDropout2d(nn.Module):
+ """A module around torch.nn.functional.dropout2d."""
+ def __init__(self):
+ super().__init__()
+ self._supports_custom_dropout = True
+
+ def forward(self, input: Tensor, p: float) -> Tensor:
+ return F.dropout2d(input, p, training=self.training)
+
+
+class SequentialWithDropout(nn.Sequential):
+ """Sequential of modules with dropout."""
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._supports_custom_dropout = True
+
+ def forward(self, x: Tensor, p: float) -> Tensor:
+ for module in self:
+ if getattr(module, '_supports_custom_dropout', False):
+ x = module(x, p)
+ else:
+ x = module(x)
+ return x
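Assuming the helpers above land in `algoperf.pytorch_utils` as shown, a typical usage sketch (illustrative only, not taken from the workloads) looks like this; `SequentialWithDropout` hands the per-call rate only to members that opt in via `_supports_custom_dropout`:

import torch

from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout

block = SequentialWithDropout(
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
    CustomDropout(),  # opts in, so it receives the per-call rate p
)
x = torch.randn(4, 16)
y = block(x, 0.1)     # p=0.1 is applied only while block.training is True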
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
index d8d7393e4..065ebd1f8 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
@@ -6,7 +6,7 @@
import torch.nn.functional as F
from torch import nn
-from algoperf.workloads.dropout_modules import CustomDropout, SequentialWithDropout
+from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout
class DenseBlock(nn.Module):
diff --git a/algoperf/workloads/dropout_modules.py b/algoperf/workloads/dropout_modules.py
deleted file mode 100644
index 6cec3f7ad..000000000
--- a/algoperf/workloads/dropout_modules.py
+++ /dev/null
@@ -1,40 +0,0 @@
-"""Custom classes to support a dynamic modulized dropout, see issue??TODO"""
-
-from torch import Tensor
-from torch import nn
-import torch.nn.functional as F
-
-
-class CustomDropout(nn.Module):
- """A module around torch.nn.functional.dropout."""
- def __init__(self):
- super().__init__()
- self._supports_custom_dropout = True
-
- def forward(self, input: Tensor, p: float) -> Tensor:
- return F.dropout(input, p, training=self.training)
-
-
-class CustomDropout2d(nn.Module):
- """A module around torch.nn.functional.dropout2d."""
- def __init__(self):
- super().__init__()
- self._supports_custom_dropout = True
-
- def forward(self, input: Tensor, p: float) -> Tensor:
- return F.dropout2d(input, p, training=self.training)
-
-
-class SequentialWithDropout(nn.Sequential):
- """Sequential of modules with dropout."""
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._supports_custom_dropout = True
-
- def forward(self, x: Tensor, p: float) -> Tensor:
- for module in self:
- if getattr(module, '_supports_custom_dropout', False):
- x = module(x, p)
- else:
- x = module(x)
- return x
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
index 8954cb737..260cb7e44 100644
--- a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
@@ -13,7 +13,7 @@
from torch.nn import functional as F
from algoperf import init_utils
-from algoperf.workloads.dropout_modules import CustomDropout2d, SequentialWithDropout
+from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout
diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py
index 1d89ea9e7..b86b88caa 100644
--- a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py
+++ b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py
@@ -9,7 +9,7 @@
from torch import nn
from algoperf import init_utils
-from algoperf.workloads.dropout_modules import CustomDropout, SequentialWithDropout
+from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout
def _make_mlp(in_dim, hidden_dims, activation_fn):
From f26ab02d987a839c481dc74d3d15c1d920165d38 Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Wed, 11 Jun 2025 18:14:26 +0200
Subject: [PATCH 039/123] remove aux_dropout from pytorch workloads
---
.../workloads/criteo1tb/criteo1tb_pytorch/workload.py | 4 +---
.../workloads/fastmri/fastmri_pytorch/workload.py | 4 +---
.../imagenet_resnet/imagenet_pytorch/workload.py | 4 +---
.../imagenet_vit/imagenet_pytorch/workload.py | 4 +---
.../librispeech_pytorch/workload.py | 11 +++--------
.../librispeech_pytorch/workload.py | 11 +++--------
algoperf/workloads/ogbg/ogbg_pytorch/workload.py | 5 +----
algoperf/workloads/wmt/wmt_pytorch/workload.py | 6 ++----
8 files changed, 13 insertions(+), 36 deletions(-)
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
index 726aa8705..638022a5e 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
@@ -68,10 +68,8 @@ def loss_fn(
def init_model_fn(
self,
rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ dropout_rate: Optional[float] = None) -> spec.ModelInitState:
"""Only dropout is used."""
- del aux_dropout_rate
torch.random.manual_seed(rng[0])
# Disable cudnn benchmark to avoid OOM errors.
torch.backends.cudnn.benchmark = False
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
index 58943de2f..9582325e1 100644
--- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
@@ -108,9 +108,7 @@ def _build_input_queue(self,
def init_model_fn(
self,
rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- del aux_dropout_rate
+ dropout_rate: Optional[float] = None) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
model = UNet(
num_pool_layers=self.num_pool_layers,
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
index ed29271f3..372cac7fa 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
@@ -157,11 +157,9 @@ def _build_dataset(
def init_model_fn(
self,
rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ dropout_rate: Optional[float] = None) -> spec.ModelInitState:
"""Dropout is unused."""
del dropout_rate
- del aux_dropout_rate
torch.random.manual_seed(rng[0])
if self.use_silu and self.use_gelu:
diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
index 97bb38515..1a6bb1381 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
@@ -24,9 +24,7 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload):
def init_model_fn(
self,
rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- del aux_dropout_rate
+ dropout_rate: Optional[float] = None) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
model = models.ViT(
dropout_rate=dropout_rate,
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
index 5ed37957e..39f33f4aa 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
@@ -64,13 +64,8 @@ def attention_temperature(self) -> float:
def init_model_fn(
self,
rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """Conformer model init function.
-
- Here we use dropout_rate as residual_dropout_rate, and aux_dropout_rate as
- input_dropout_rate.
- """
+ dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ """Conformer model init function."""
torch.random.manual_seed(rng[0])
# Configure torch backends to avoid OOM errors.
torch.backends.cudnn.benchmark = False
@@ -86,7 +81,7 @@ def init_model_fn(
attention_residual_dropout_rate=dropout_rate,
feed_forward_residual_dropout_rate=dropout_rate,
conv_residual_dropout_rate=dropout_rate,
- input_dropout_rate=aux_dropout_rate,
+ input_dropout_rate=dropout_rate,
use_specaug=self.use_specaug,
attention_temperature=self.attention_temperature,
use_post_layer_norm=self.use_post_layer_norm,
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
index e5387f5cb..932ba9392 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
@@ -25,19 +25,14 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload):
def init_model_fn(
self,
rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """Deepspeech model init function.
-
- Here we use dropout_rate as feed_forward_dropout_rate, and aux_dropout_rate
- as input_dropout_rate.
- """
+ dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ """Deepspeech model init function."""
torch.random.manual_seed(rng[0])
model = DeepspeechEncoderDecoder(
DeepspeechConfig(
feed_forward_dropout_rate=dropout_rate,
use_specaug=self.use_specaug,
- input_dropout_rate=aux_dropout_rate,
+ input_dropout_rate=dropout_rate,
use_tanh=self.use_tanh,
enable_residual_connections=self.enable_residual_connections,
enable_decoder_layer_norm=self.enable_decoder_layer_norm,
diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
index 45295ac7f..1dd85951d 100644
--- a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
+++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
@@ -139,10 +139,7 @@ def _build_input_queue(self,
def init_model_fn(
self,
rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """aux_dropout_rate is unused."""
- del aux_dropout_rate
+ dropout_rate: Optional[float] = None) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
model = GNN(
num_outputs=self._num_outputs,
diff --git a/algoperf/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py
index d0716d6c8..64eea73b7 100644
--- a/algoperf/workloads/wmt/wmt_pytorch/workload.py
+++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py
@@ -168,9 +168,7 @@ def translate_and_calculate_bleu(self,
def init_model_fn(
self,
rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """aux_dropout_rate is used as attention_dropout_rate."""
+ dropout_rate: Optional[float] = None) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
if self.activation == 'relu':
@@ -182,7 +180,7 @@ def init_model_fn(
model = Transformer(
dropout_rate=dropout_rate,
- attention_dropout_rate=aux_dropout_rate,
+ attention_dropout_rate=dropout_rate,
pre_ln=self.pre_ln,
attention_temp=self.attention_temp,
activation=activation,
From 872393770c6e64f1cf5cb6e6a2b0227c5e88685d Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Wed, 11 Jun 2025 10:22:56 -0700
Subject: [PATCH 040/123] Update submission.py
---
submissions/template/submission.py | 10 ++++------
1 file changed, 4 insertions(+), 6 deletions(-)
diff --git a/submissions/template/submission.py b/submissions/template/submission.py
index a4fdc62b4..2269b7dbb 100644
--- a/submissions/template/submission.py
+++ b/submissions/template/submission.py
@@ -15,9 +15,7 @@ def init_optimizer_state(workload: spec.Workload,
hyperparameters: spec.Hyperparameters,
rng: spec.RandomState) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule.
- Returns:
- optimizer state
- optimizer_update_fn
+ Returns: spec.OptimizerState initialized optimizer state
"""
pass
@@ -37,9 +35,9 @@ def update_params(
train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
"""
Returns:
- (new_optimizer_state, update_fn)
- new_params
- new_model_state
+ spec.OptimizerState: new optimizer state
+ spec.ParameterTypeTree: new params
+ new_model_state: new model state
"""
pass
From e0a0e624b7cd70fea51a054d85cd1b418076cdef Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Thu, 12 Jun 2025 10:57:15 +0200
Subject: [PATCH 041/123] criteo rm dropout from init
---
.../criteo1tb_pytorch/models_dropout.py | 21 +++--------
.../criteo1tb/criteo1tb_pytorch/workload.py | 5 +--
.../test_model_equivalence.py | 35 ++++++++-----------
3 files changed, 20 insertions(+), 41 deletions(-)
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
index 065ebd1f8..2ac5c2d1b 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
@@ -8,6 +8,8 @@
from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout
+DEFAULT_DROPOUT_RATE = 0.0
+
class DenseBlock(nn.Module):
"""Dense block with optional residual connection.""" ""
@@ -69,7 +71,6 @@ def __init__(self,
mlp_bottom_dims=(256, 256, 256),
mlp_top_dims=(256, 256, 256, 256, 1),
embed_dim=128,
- dropout_rate=0.0,
use_layer_norm=False,
embedding_init_multiplier=None):
super().__init__()
@@ -79,10 +80,6 @@ def __init__(self,
self.mlp_bottom_dims = mlp_bottom_dims
self.mlp_top_dims = mlp_top_dims
self.embed_dim = embed_dim
- if dropout_rate is None:
- self.dropout_rate = 0.0
- else:
- self.dropout_rate = dropout_rate
# Ideally, we should use the pooled embedding implementation from
# `TorchRec`. However, in order to have identical implementation
@@ -152,9 +149,7 @@ def __init__(self,
0.,
math.sqrt(1. / module.out_features))
- def forward(self, x, dropout_rate=None):
- if dropout_rate is None:
- dropout_rate = self.dropout_rate
+ def forward(self, x, dropout_rate=DEFAULT_DROPOUT_RATE):
batch_size = x.shape[0]
@@ -196,7 +191,6 @@ def __init__(self,
mlp_bottom_dims=(512, 256, 128),
mlp_top_dims=(1024, 1024, 512, 256, 1),
embed_dim=128,
- dropout_rate=0.0,
use_layer_norm=False,
embedding_init_multiplier=None):
super().__init__()
@@ -207,11 +201,6 @@ def __init__(self,
self.mlp_top_dims = mlp_top_dims
self.embed_dim = embed_dim
self.embedding_init_multiplier = embedding_init_multiplier
- self.dropout_rate = dropout_rate
- if dropout_rate is None:
- self.dropout_rate = 0.0
- else:
- self.dropout_rate = dropout_rate
# Ideally, we should use the pooled embedding implementation from
# `TorchRec`. However, in order to have identical implementation
@@ -281,9 +270,7 @@ def __init__(self,
0.,
math.sqrt(1. / module.out_features))
- def forward(self, x, dropout_rate=None):
- if dropout_rate is None:
- dropout_rate = self.dropout_rate
+ def forward(self, x, dropout_rate=DEFAULT_DROPOUT_RATE):
batch_size = x.shape[0]
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
index 638022a5e..b128f5bd5 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
@@ -67,9 +67,7 @@ def loss_fn(
def init_model_fn(
self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """Only dropout is used."""
+ rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
# Disable cudnn benchmark to avoid OOM errors.
torch.backends.cudnn.benchmark = False
@@ -83,7 +81,6 @@ def init_model_fn(
mlp_bottom_dims=self.mlp_bottom_dims,
mlp_top_dims=self.mlp_top_dims,
embed_dim=self.embed_dim,
- dropout_rate=dropout_rate,
use_layer_norm=self.use_layer_norm,
embedding_init_multiplier=self.embedding_init_multiplier)
self._param_shapes = param_utils.pytorch_param_shapes(model)
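Unlike the earlier fallback-to-`self.dropout_rate` pattern, after this patch the Criteo modules keep no dropout state at all and default to a module-level constant. A hypothetical module with the same shape (not the repo code):

import torch
import torch.nn.functional as F

DEFAULT_DROPOUT_RATE = 0.0


class StatelessDropoutModel(torch.nn.Module):
  """Keeps no dropout attribute; the caller supplies the rate per forward."""

  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(8, 8)

  def forward(self, x, dropout_rate=DEFAULT_DROPOUT_RATE):
    return F.dropout(self.linear(x), dropout_rate, training=self.training)


model = StatelessDropoutModel()
y_train = model(torch.randn(2, 8), dropout_rate=0.1)  # rate supplied by the caller
model.eval()
y_eval = model(torch.randn(2, 8))  # default 0.0; eval mode disables dropout anyway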
diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
index b9b1232ef..c4f074ff3 100644
--- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
@@ -34,8 +34,6 @@
class ModelEquivalenceTest(parameterized.TestCase):
@parameterized.named_parameters(
- dict(testcase_name='DLRMResNet, p=None', model='dlrm_resnet', dropout_rate=None),
- dict(testcase_name='DlrmSmall, p=None', model='dlrm_small', dropout_rate=None),
dict(testcase_name='DLRMResNet, p=0.0', model='dlrm_resnet', dropout_rate=0.0),
dict(testcase_name='DlrmSmall, p=0.0', model='dlrm_small', dropout_rate=0.0),
dict(testcase_name='DLRMResNet, p=0.1', model='dlrm_resnet', dropout_rate=0.1),
@@ -50,27 +48,24 @@ def test_forward(self, model, dropout_rate):
else (OriginalDlrmSmall, CustomDlrmSmall)
)
- # Test initalizing custom model with a None dropout_rate
- for custom_init_dropout_rate in [dropout_rate, None]:
+ torch.manual_seed(SEED)
+ orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate)
+ orig.to(DEVICE)
- torch.manual_seed(SEED)
- orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate)
- orig.to(DEVICE)
+ torch.manual_seed(SEED)
+ cust = CustCls(vocab_size=VOCAB)
+ cust.to(DEVICE)
- torch.manual_seed(SEED)
- cust = CustCls(vocab_size=VOCAB, dropout_rate=custom_init_dropout_rate)
- cust.to(DEVICE)
-
- if TORCH_COMPILE:
- orig = torch.compile(orig); cust = torch.compile(cust)
-
- x = torch.randn(BATCH, FEATURES, device=DEVICE)
+ if TORCH_COMPILE:
+ orig = torch.compile(orig); cust = torch.compile(cust)
+
+ x = torch.randn(BATCH, FEATURES, device=DEVICE)
- for mode in ('train', 'eval'):
- getattr(orig, mode)(); getattr(cust, mode)()
- torch.manual_seed(SEED); y1 = orig(x)
- torch.manual_seed(SEED); y2 = cust(x, dropout_rate)
- assert_close(y1, y2, atol=0, rtol=0)
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)(); getattr(cust, mode)()
+ torch.manual_seed(SEED); y1 = orig(x)
+ torch.manual_seed(SEED); y2 = cust(x, dropout_rate)
+ assert_close(y1, y2, atol=0, rtol=0)
if __name__ == '__main__':
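Taken together, the DLRM hunks above converge on one pattern: drop dropout_rate from __init__, add a module-level DEFAULT_DROPOUT_RATE, and take the rate as a forward() argument so it can change between steps without rebuilding the model. A minimal sketch of that pattern (illustrative class and dimensions, not the repo's DLRM code):

import torch
import torch.nn.functional as F
from torch import nn

DEFAULT_DROPOUT_RATE = 0.0  # mirrors the constant added to models_dropout.py

class CallTimeDropoutMLP(nn.Module):
  """Hypothetical module: the dropout rate is a forward() argument, not stored state."""

  def __init__(self, in_dim: int = 16, hidden_dim: int = 32):
    super().__init__()
    self.fc1 = nn.Linear(in_dim, hidden_dim)
    self.fc2 = nn.Linear(hidden_dim, 1)

  def forward(self, x: torch.Tensor,
              dropout_rate: float = DEFAULT_DROPOUT_RATE) -> torch.Tensor:
    x = F.relu(self.fc1(x))
    # Functional dropout: active only in train mode, rate chosen per call.
    x = F.dropout(x, dropout_rate, training=self.training)
    return self.fc2(x)

model = CallTimeDropoutMLP()
y = model(torch.randn(4, 16), dropout_rate=0.1)  # rate supplied at call time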
From 1e2f379a955bbd75c57042cb6b73750fd3f06eb7 Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Thu, 12 Jun 2025 11:07:04 +0200
Subject: [PATCH 042/123] criteo rm dropout from init
---
.../criteo1tb_pytorch/test_model_equivalence.py | 11 ++++++++---
1 file changed, 8 insertions(+), 3 deletions(-)
diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
index c4f074ff3..f40f7b3df 100644
--- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
@@ -63,10 +63,15 @@ def test_forward(self, model, dropout_rate):
for mode in ('train', 'eval'):
getattr(orig, mode)(); getattr(cust, mode)()
- torch.manual_seed(SEED); y1 = orig(x)
- torch.manual_seed(SEED); y2 = cust(x, dropout_rate)
+ torch.manual_seed(SEED)
+ y1 = orig(x)
+ torch.manual_seed(SEED)
+ if mode == 'train':
+ y2 = cust(x, dropout_rate)
+ else:
+ y2 = cust(x)
assert_close(y1, y2, atol=0, rtol=0)
-
+
if __name__ == '__main__':
absltest.main()
From f10e3dc6bdebdf87a3460b600c729c63512f7397 Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Thu, 12 Jun 2025 11:11:26 +0200
Subject: [PATCH 043/123] criteo rm dropout from init
---
.../dropout_fix/criteo1tb_pytorch/test_model_equivalence.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
index f40f7b3df..32aa5b34f 100644
--- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
@@ -62,7 +62,9 @@ def test_forward(self, model, dropout_rate):
x = torch.randn(BATCH, FEATURES, device=DEVICE)
for mode in ('train', 'eval'):
- getattr(orig, mode)(); getattr(cust, mode)()
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
torch.manual_seed(SEED)
y1 = orig(x)
torch.manual_seed(SEED)
@@ -70,6 +72,7 @@ def test_forward(self, model, dropout_rate):
y2 = cust(x, dropout_rate)
else:
y2 = cust(x)
+
assert_close(y1, y2, atol=0, rtol=0)
From 027b053e838a57e5a1c33389d0b8f1b0d4269aa2 Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Thu, 12 Jun 2025 11:25:31 +0200
Subject: [PATCH 044/123] criteo rm dropout from init
---
.../workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
index 2ac5c2d1b..b5ee465e2 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
@@ -3,7 +3,6 @@
import math
import torch
-import torch.nn.functional as F
from torch import nn
from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout
@@ -30,7 +29,7 @@ def __init__(self, module, resnet=False):
self.resnet = resnet
self._supports_custom_dropout = True
- def forward(self, x, p=None):
+ def forward(self, x, p):
return self.module(x, p) + x if self.resnet else self.module(x, p)
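The forward(self, x, p) change above makes the rate a required argument, which presumes the enclosing container always supplies it. algoperf.pytorch_utils.SequentialWithDropout is not shown in this series; a plausible sketch, assuming it uses the _supports_custom_dropout flag set in __init__ to decide which children receive the rate (not the actual repo implementation):

import torch
from torch import nn

class SequentialWithDropoutSketch(nn.Sequential):
  """Assumed behavior: forward the dropout rate only to opted-in children."""

  def forward(self, x: torch.Tensor, p: float) -> torch.Tensor:
    for module in self:
      if getattr(module, '_supports_custom_dropout', False):
        x = module(x, p)   # e.g. a DenseBlockWithDropout-style wrapper
      else:
        x = module(x)      # plain layers never see the rate
    return x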
From 74c43aa20e3c61b488234014816f1bffeec1508d Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Thu, 12 Jun 2025 11:26:04 +0200
Subject: [PATCH 045/123] fastmri rm dropout from init
---
.../fastmri/fastmri_pytorch/models_dropout.py | 12 ++-----
.../fastmri/fastmri_pytorch/workload.py | 6 ++--
.../fastmri_pytorch/test_model_equivalence.py | 32 ++++++++-----------
3 files changed, 19 insertions(+), 31 deletions(-)
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
index 260cb7e44..73b1d81d9 100644
--- a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
@@ -15,6 +15,7 @@
from algoperf import init_utils
from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout
+DEFAULT_DROPOUT_RATE = 0.0
class UNet(nn.Module):
@@ -29,7 +30,6 @@ def __init__(self,
out_chans: int = 1,
num_channels: int = 32,
num_pool_layers: int = 4,
- dropout_rate: Optional[float] = 0.0,
use_tanh: bool = False,
use_layer_norm: bool = False) -> None:
super().__init__()
@@ -38,10 +38,6 @@ def __init__(self,
self.out_chans = out_chans
self.num_channels = num_channels
self.num_pool_layers = num_pool_layers
- if dropout_rate is None:
- self.dropout_rate = 0.0
- else:
- self.dropout_rate = dropout_rate
self.down_sample_layers = nn.ModuleList([
ConvBlock(in_chans,
@@ -78,9 +74,7 @@ def __init__(self,
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
init_utils.pytorch_default_init(m)
- def forward(self, x: Tensor, dropout_rate: Optional[float] = None) -> Tensor:
- if dropout_rate is None:
- dropout_rate = self.dropout_rate
+ def forward(self, x: Tensor, dropout_rate: Optional[float] = DEFAULT_DROPOUT_RATE) -> Tensor:
stack = []
output = x
@@ -145,7 +139,7 @@ def __init__(self,
CustomDropout2d(),
)
- def forward(self, x: Tensor, dropout_rate: Optional[float] = None) -> Tensor:
+ def forward(self, x: Tensor, dropout_rate: float) -> Tensor:
return self.conv_layers(x, dropout_rate)
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
index 9582325e1..6da0bb0af 100644
--- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
@@ -107,15 +107,13 @@ def _build_input_queue(self,
def init_model_fn(
self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
model = UNet(
num_pool_layers=self.num_pool_layers,
num_channels=self.num_channels,
use_tanh=self.use_tanh,
- use_layer_norm=self.use_layer_norm,
- dropout_rate=dropout_rate)
+ use_layer_norm=self.use_layer_norm)
self._param_shapes = param_utils.pytorch_param_shapes(model)
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
model.to(DEVICE)
diff --git a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
index 6339ff21b..c71ff8980 100644
--- a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
@@ -33,36 +33,32 @@ def fwd_pass(self, orig, cust, dropout_rate):
torch.manual_seed(0)
y1 = orig(x)
torch.manual_seed(0)
- y2 = cust(x, dropout_rate)
+ if mode == 'train':
+ y2 = cust(x, dropout_rate)
+ else:
+ y2 = cust(x)
assert_close(y1, y2, atol=0, rtol=0)
@parameterized.named_parameters(
- dict(testcase_name='p=None', dropout_rate=None),
dict(testcase_name='p=0.0', dropout_rate=0.0),
dict(testcase_name='p=0.1', dropout_rate=0.1),
+ dict(testcase_name='p=0.7', dropout_rate=0.7),
dict(testcase_name='p=1.0', dropout_rate=1.0),
)
def test_dropout_values(self, dropout_rate):
"""Test different values of dropout_rate."""
- # Test initalizing custom model with a None dropout_rate
- for custom_init_dropout_rate in [dropout_rate, None]:
-
- torch.manual_seed(SEED)
- orig = OriginalUNet(
- IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate
- ).to(DEVICE)
+ torch.manual_seed(SEED)
+ orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate).to(DEVICE)
- torch.manual_seed(SEED)
- cust = CustomUNet(
- IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=custom_init_dropout_rate
- ).to(DEVICE)
+ torch.manual_seed(SEED)
+ cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
- cust.load_state_dict(orig.state_dict()) # sync weights
- if TORCH_COMPILE:
- orig = torch.compile(orig); cust = torch.compile(cust)
+ cust.load_state_dict(orig.state_dict()) # sync weights
+ if TORCH_COMPILE:
+ orig = torch.compile(orig); cust = torch.compile(cust)
- self.fwd_pass(orig, cust, dropout_rate)
+ self.fwd_pass(orig, cust, dropout_rate)
@parameterized.named_parameters(
@@ -71,7 +67,7 @@ def test_dropout_values(self, dropout_rate):
dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True),
dict(testcase_name='both', use_tanh=True, use_layer_norm=True),
)
- def test_arch_setups(self, use_tanh, use_layer_norm):
+ def test_arch_configs(self, use_tanh, use_layer_norm):
"""Test different architecture configurations, fixed dropout_rate."""
dropout_rate = 0.1
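Across these test rewrites the recipe is the same: construct the reference model with dropout_rate at init, construct the patched model without it, sync weights via load_state_dict, then compare bit-exact seeded forward passes in train and eval mode. Stripped to its essentials (orig/cust stand in for any matching model pair):

import torch
from torch.testing import assert_close

def check_equivalence(orig: torch.nn.Module,
                      cust: torch.nn.Module,
                      x: torch.Tensor,
                      dropout_rate: float) -> None:
  cust.load_state_dict(orig.state_dict())  # sync weights
  for mode in ('train', 'eval'):
    getattr(orig, mode)()
    getattr(cust, mode)()
    torch.manual_seed(0)
    y1 = orig(x)
    torch.manual_seed(0)
    # Train mode must reproduce the original dropout masks exactly;
    # eval mode also exercises the new call-time default.
    y2 = cust(x, dropout_rate) if mode == 'train' else cust(x)
    assert_close(y1, y2, atol=0, rtol=0)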
From 64276ef67cb0ab56bac998c665c82d62e7722087 Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Thu, 12 Jun 2025 11:44:11 +0200
Subject: [PATCH 046/123] vit rm dropout at init
---
.../imagenet_pytorch/models_dropout.py | 72 +++++++------------
.../wmt/wmt_pytorch/models_dropout.py | 2 +-
.../test_model_equivalence.py | 72 ++++++++++++-------
3 files changed, 72 insertions(+), 74 deletions(-)
diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py
index f5e315fd7..8641847b0 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py
@@ -14,7 +14,9 @@
from algoperf import init_utils
from algoperf import spec
-from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention
+from algoperf.workloads.wmt.wmt_pytorch.models_dropout import MultiheadAttention
+
+DEFAULT_DROPOUT_RATE = 0.0
def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor:
@@ -41,14 +43,12 @@ def __init__(
self,
width: int,
mlp_dim: Optional[int] = None, # Defaults to 4x input dim.
- use_glu: bool = False,
- dropout_rate: float = 0.0) -> None:
+ use_glu: bool = False) -> None:
super().__init__()
self.width = width
self.mlp_dim = mlp_dim or 4 * width
self.use_glu = use_glu
- self.dropout_rate = dropout_rate
self.linear1 = nn.Linear(self.width, self.mlp_dim)
self.act_fnc = nn.GELU(approximate='tanh')
@@ -69,9 +69,7 @@ def reset_parameters(self) -> None:
if module.bias is not None:
module.bias.data.normal_(std=1e-6)
- def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor:
- if dropout_rate is None:
- dropout_rate = self.dropout_rate
+ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
x = self.linear1(x)
x = self.act_fnc(x)
@@ -90,8 +88,7 @@ class SelfAttention(nn.Module):
def __init__(self,
width: int,
- num_heads: int = 8,
- dropout_rate: float = 0.0) -> None:
+ num_heads: int = 8) -> None:
super().__init__()
self.width = width
@@ -102,7 +99,6 @@ def __init__(self,
self.head_dim = int(width / num_heads)
self.all_head_dim = self.num_heads * self.head_dim
- self.dropout_rate = dropout_rate
self.query = nn.Linear(self.width, self.all_head_dim)
self.key = nn.Linear(self.width, self.all_head_dim)
@@ -122,9 +118,7 @@ def transpose_for_scores(self, x: spec.Tensor) -> spec.Tensor:
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
- def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor:
- if dropout_rate is None:
- dropout_rate = self.dropout_rate
+ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
mixed_query_layer = self.query(x)
@@ -136,7 +130,7 @@ def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor:
attention_scores = attention_scores / math.sqrt(self.head_dim)
attention_probs = F.softmax(attention_scores, dim=-1)
- attention_probs = F.dropout(attention_probs, dropout_rate, training=self.training)
+ attention_probs = F.dropout(attention_probs, dropout_rate, self.training)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
@@ -154,8 +148,7 @@ def __init__(self,
mlp_dim: Optional[int] = None,
num_heads: int = 12,
use_glu: bool = False,
- use_post_layer_norm: bool = False,
- dropout_rate: float = 0.0) -> None:
+ use_post_layer_norm: bool = False) -> None:
super().__init__()
self.width = width
@@ -163,7 +156,6 @@ def __init__(self,
self.num_heads = num_heads
self.use_glu = use_glu
self.use_post_layer_norm = use_post_layer_norm
- self.dropout_rate = dropout_rate
self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6)
self.self_attention1 = SelfAttention(self.width, self.num_heads)
@@ -171,32 +163,29 @@ def __init__(self,
self.mlp3 = MlpBlock(
width=self.width,
mlp_dim=self.mlp_dim,
- use_glu=self.use_glu,
- dropout_rate=dropout_rate)
+ use_glu=self.use_glu)
- def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor:
- if dropout_rate is None:
- dropout_rate = self.dropout_rate
+ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
if not self.use_post_layer_norm:
y = self.layer_norm0(x)
- y = self.self_attention1(y)
+ y = self.self_attention1(y, dropout_rate)
y = F.dropout(y, dropout_rate, training=self.training)
x = x + y
y = self.layer_norm2(x)
- y = self.mlp3(y)
+ y = self.mlp3(y, dropout_rate)
y = F.dropout(y, dropout_rate, training=self.training)
x = x + y
else:
y = x
- y = self.self_attention1(y)
+ y = self.self_attention1(y, dropout_rate)
y = F.dropout(y, dropout_rate, training=self.training)
x = x + y
x = self.layer_norm0(x)
y = x
- y = self.mlp3(y)
+ y = self.mlp3(y, dropout_rate)
y = F.dropout(y, dropout_rate, training=self.training)
x = x + y
x = self.layer_norm2(x)
@@ -212,8 +201,7 @@ def __init__(self,
mlp_dim: Optional[int] = None,
num_heads: int = 12,
use_glu: bool = False,
- use_post_layer_norm: bool = False,
- dropout_rate: float = 0.0) -> None:
+ use_post_layer_norm: bool = False) -> None:
super().__init__()
self.depth = depth
@@ -228,8 +216,7 @@ def __init__(self,
self.mlp_dim,
self.num_heads,
self.use_glu,
- self.use_post_layer_norm,
- dropout_rate) for _ in range(depth)
+ self.use_post_layer_norm) for _ in range(depth)
])
if not self.use_post_layer_norm:
@@ -237,10 +224,10 @@ def __init__(self,
else:
self.encoder_norm = None
- def forward(self, x: spec.Tensor) -> spec.Tensor:
+ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
# Input Encoder.
for block in self.net:
- x = block(x)
+ x = block(x, dropout_rate)
if not self.use_post_layer_norm:
return self.encoder_norm(x)
else:
@@ -267,13 +254,13 @@ def __init__(self,
self.layer_norm = nn.LayerNorm(self.width, eps=1e-6)
self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim)
- def forward(self, x: spec.Tensor) -> spec.Tensor:
+ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
n, _, _ = x.shape
probe = torch.tile(self.probe, [n, 1, 1])
- x = self.mha(probe, x)[0]
+ x = self.mha(probe, x, dropout_rate=dropout_rate)[0]
y = self.layer_norm(x)
- x = x + self.mlp(y)
+ x = x + self.mlp(y, dropout_rate)
return x[:, 0]
@@ -293,15 +280,12 @@ def __init__(
mlp_dim: Optional[int] = None, # Defaults to 4x input dim.
num_heads: int = 12,
rep_size: Union[int, bool] = True,
- dropout_rate: Optional[float] = 0.0,
head_zeroinit: bool = True,
use_glu: bool = False,
use_post_layer_norm: bool = False,
use_map: bool = False,
dtype: Any = torch.float32) -> None:
super().__init__()
- if dropout_rate is None:
- dropout_rate = 0.0
self.num_classes = num_classes
self.patch_size = patch_size
@@ -315,7 +299,6 @@ def __init__(
self.use_post_layer_norm = use_post_layer_norm
self.use_map = use_map
self.dtype = dtype
- self.dropout_rate = dropout_rate
if self.rep_size:
rep_size = self.width if self.rep_size is True else self.rep_size
@@ -334,8 +317,7 @@ def __init__(
mlp_dim=self.mlp_dim,
num_heads=self.num_heads,
use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm,
- dropout_rate=dropout_rate)
+ use_post_layer_norm=self.use_post_layer_norm)
if self.num_classes:
self.head = nn.Linear(self.width, self.num_classes)
@@ -363,9 +345,7 @@ def reset_parameters(self) -> None:
def get_posemb(self, x: spec.Tensor) -> spec.Tensor:
return posemb_sincos_2d(x).type(self.dtype)
- def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor:
- if dropout_rate is None:
- dropout_rate = self.dropout_rate
+ def forward(self, x: spec.Tensor, dropout_rate=DEFAULT_DROPOUT_RATE) -> spec.Tensor:
# Patch extraction.
x = self.conv_patch_extract(x)
@@ -379,10 +359,10 @@ def forward(self, x: spec.Tensor, dropout_rate=None) -> spec.Tensor:
x = x + pes
x = F.dropout(x, dropout_rate, training=self.training)
- x = self.encoder(x)
+ x = self.encoder(x, dropout_rate)
if self.use_map:
- x = self.map(x)
+ x = self.map(x, dropout_rate)
else:
x = torch.mean(x, dim=1)
diff --git a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py
index c5014d87d..6e265cd7f 100644
--- a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py
+++ b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py
@@ -861,7 +861,7 @@ def forward(self,
max_len: Optional[int] = None,
cache: Optional[dict] = None,
index: Optional[int] = None,
- dropout_rate: Optional[float] = None) -> Any:
+ dropout_rate: Optional[float] = None) -> Any: # TODO: (nico) remove default?!
r"""
Args:
x: Batch of input sequences of shape
diff --git a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
index 56644f152..32db2e7d4 100644
--- a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
@@ -32,41 +32,41 @@ def fwd_pass(self, orig, cust, dropout_rate):
for mode in ('train', 'eval'):
getattr(orig, mode)()
getattr(cust, mode)()
- torch.manual_seed(0); y1 = orig(x)
- torch.manual_seed(0); y2 = cust(x, dropout_rate)
+ torch.manual_seed(0)
+ y1 = orig(x)
+ torch.manual_seed(0)
+ if mode == 'train':
+ y2 = cust(x, dropout_rate)
+ else:
+ y2 = cust(x)
assert_close(y1, y2, atol=0, rtol=0)
@parameterized.named_parameters(
- dict(testcase_name='p=None', dropout_rate=None),
dict(testcase_name='p=0.0', dropout_rate=0.0),
dict(testcase_name='p=0.1', dropout_rate=0.1),
dict(testcase_name='p=0.6', dropout_rate=0.6),
dict(testcase_name='p=1.0', dropout_rate=1.0),
)
def test_dropout_values(self, dropout_rate):
- """Test different dropout_values."""
-
- # Test initalizing custom model with a None dropout_rate
- for custom_init_dropout_rate in [dropout_rate, None]:
-
- torch.manual_seed(SEED)
- orig = OriginalVit(
- width=WIDTH,
- depth=DEPTH,
- num_heads=HEADS,
- dropout_rate=dropout_rate,
- ).to(DEVICE)
-
- torch.manual_seed(SEED)
- cust = CustomVit(
- width=WIDTH,
- depth=DEPTH,
- num_heads=HEADS,
- dropout_rate=custom_init_dropout_rate,
- ).to(DEVICE)
-
- cust.load_state_dict(orig.state_dict()) # sync weights
- self.fwd_pass(orig, cust, dropout_rate)
+ """Test different dropout_rates."""
+
+ torch.manual_seed(SEED)
+ orig = OriginalVit(
+ width=WIDTH,
+ depth=DEPTH,
+ num_heads=HEADS,
+ dropout_rate=dropout_rate,
+ ).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustomVit(
+ width=WIDTH,
+ depth=DEPTH,
+ num_heads=HEADS,
+ ).to(DEVICE)
+
+ cust.load_state_dict(orig.state_dict()) # sync weights
+ self.fwd_pass(orig, cust, dropout_rate)
@parameterized.named_parameters([
@@ -101,12 +101,30 @@ def test_arch(self, use_glu, use_post_ln, use_map):
use_glu=use_glu,
use_post_layer_norm=use_post_ln,
use_map=use_map,
- dropout_rate=None,
).to(DEVICE)
cust.load_state_dict(orig.state_dict()) # sync weights
self.fwd_pass(orig, cust, dropout_rate)
+ @parameterized.named_parameters(
+ dict(testcase_name=''),
+ )
+ def test_default_dropout(self):
+ """Test default dropout_rate."""
+
+ torch.manual_seed(SEED)
+ orig = OriginalVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE)
+ torch.manual_seed(SEED)
+ cust = CustomVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE)
+ cust.load_state_dict(orig.state_dict()) # sync weights
+
+ x = torch.randn(BATCH, C, H, W, device=DEVICE)
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)(); getattr(cust, mode)()
+ torch.manual_seed(0); y1 = orig(x)
+ torch.manual_seed(0); y2 = cust(x)
+ assert_close(y1, y2, atol=0, rtol=0)
+
if __name__ == '__main__':
absltest.main()
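The ViT hunks replace per-child cached rates with a single dropout_rate argument threaded from the top-level forward through Encoder, each block, and its attention/MLP children. A compact sketch of that plumbing (simplified residual block, not the repo's Encoder1DBlock):

import torch
import torch.nn.functional as F
from torch import nn

class TinyBlock(nn.Module):

  def __init__(self, width: int):
    super().__init__()
    self.norm = nn.LayerNorm(width)
    self.mlp = nn.Linear(width, width)

  def forward(self, x: torch.Tensor, dropout_rate: float) -> torch.Tensor:
    y = self.mlp(self.norm(x))
    y = F.dropout(y, dropout_rate, training=self.training)
    return x + y  # residual connection

class TinyEncoder(nn.Module):

  def __init__(self, width: int, depth: int):
    super().__init__()
    self.blocks = nn.ModuleList(TinyBlock(width) for _ in range(depth))

  def forward(self, x: torch.Tensor, dropout_rate: float = 0.0) -> torch.Tensor:
    for block in self.blocks:
      x = block(x, dropout_rate)  # one rate, passed down at every call
    return x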
From 44029d2ef1b8ac81556e90e40b119b642c2b4e41 Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Thu, 12 Jun 2025 11:44:16 +0200
Subject: [PATCH 047/123] vit rm dropout at init
---
algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
index 1a6bb1381..20bd3828b 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
@@ -23,11 +23,9 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload):
def init_model_fn(
self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
model = models.ViT(
- dropout_rate=dropout_rate,
num_classes=self._num_classes,
use_glu=self.use_glu,
use_post_layer_norm=self.use_post_layer_norm,
From 44ffec1d6821228eb839a0069b21daeaf32cb08c Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Thu, 12 Jun 2025 11:47:52 +0200
Subject: [PATCH 048/123] add default dropout test
---
.../test_model_equivalence.py | 33 ++++++++++++++-----
1 file changed, 25 insertions(+), 8 deletions(-)
diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
index 32aa5b34f..c59331a3d 100644
--- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
@@ -23,7 +23,6 @@
FEATURES = DENSE + SPARSE
VOCAB = 1000
DEVICE = 'cuda'
-TORCH_COMPILE = False
SEED = 1996
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
@@ -49,16 +48,11 @@ def test_forward(self, model, dropout_rate):
)
torch.manual_seed(SEED)
- orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate)
- orig.to(DEVICE)
+ orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate).to(DEVICE)
torch.manual_seed(SEED)
- cust = CustCls(vocab_size=VOCAB)
- cust.to(DEVICE)
+ cust = CustCls(vocab_size=VOCAB).to(DEVICE)
- if TORCH_COMPILE:
- orig = torch.compile(orig); cust = torch.compile(cust)
-
x = torch.randn(BATCH, FEATURES, device=DEVICE)
for mode in ('train', 'eval'):
@@ -75,6 +69,29 @@ def test_forward(self, model, dropout_rate):
assert_close(y1, y2, atol=0, rtol=0)
+ @parameterized.named_parameters(
+ dict(testcase_name='DLRMResNet, default', model='dlrm_resnet'),
+ dict(testcase_name='DlrmSmall, default', model='dlrm_small'),
+ )
+ def test_default_dropout(self, model):
+ """Test default dropout_rate."""
+ OrigCls, CustCls = (
+ (OriginalDLRMResNet, CustomDLRMResNet)
+ if model == 'dlrm_resnet'
+ else (OriginalDlrmSmall, CustomDlrmSmall)
+ )
+
+ torch.manual_seed(SEED)
+ orig = OrigCls(vocab_size=VOCAB).to(DEVICE)
+ torch.manual_seed(SEED)
+ cust = CustCls(vocab_size=VOCAB).to(DEVICE)
+
+ x = torch.randn(BATCH, FEATURES, device=DEVICE)
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)(); getattr(cust, mode)()
+ torch.manual_seed(0); y1 = orig(x)
+ torch.manual_seed(0); y2 = cust(x)
+ assert_close(y1, y2, atol=0, rtol=0)
if __name__ == '__main__':
absltest.main()
From 9d12fa65b1963ab1b77d1cd8787ddad380bf803a Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Thu, 12 Jun 2025 11:49:55 +0200
Subject: [PATCH 049/123] add default dropout test
---
.../fastmri_pytorch/test_model_equivalence.py | 20 ++++++++++++++++++-
1 file changed, 19 insertions(+), 1 deletion(-)
diff --git a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
index c71ff8980..6c8ca896c 100644
--- a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
@@ -15,7 +15,7 @@
BATCH, IN_CHANS, H, W = 4, 1, 256, 256
OUT_CHANS, C, LAYERS = 1, 32, 4
DEVICE = 'cuda'
-TORCH_COMPILE = True
+TORCH_COMPILE = False
SEED = 1996
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
@@ -89,6 +89,24 @@ def test_arch_configs(self, use_tanh, use_layer_norm):
self.fwd_pass(orig, cust, dropout_rate)
+ @parameterized.named_parameters(
+ dict(testcase_name=''),
+ )
+ def test_default_dropout(self):
+ """Test default dropout_rate."""
+
+ torch.manual_seed(SEED)
+ orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
+ torch.manual_seed(SEED)
+ cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
+ cust.load_state_dict(orig.state_dict()) # sync weights
+
+ x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)(); getattr(cust, mode)()
+ torch.manual_seed(0); y1 = orig(x)
+ torch.manual_seed(0); y2 = cust(x)
+ assert_close(y1, y2, atol=0, rtol=0)
if __name__ == '__main__':
absltest.main()
From ac45a9fc07aab0b3af28d06ca7540acb06bcc561 Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Thu, 12 Jun 2025 12:36:14 +0200
Subject: [PATCH 050/123] conformer: rm dropout_rate from init
---
.../librispeech_pytorch/models_dropout.py | 41 ++-----
.../librispeech_pytorch/workload.py | 7 +-
.../test_model_equivalence.py | 106 +++++++++++-------
3 files changed, 74 insertions(+), 80 deletions(-)
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py
index 9ff662fb8..f77c8a814 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py
@@ -17,6 +17,11 @@
from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \
SpecAug
+DEFAULT_ATTN_RESIDUAL_DROPOUT_RATE = 0.1
+DEFAULT_CONV_RESIDUAL_DROPOUT_RATE = 0.0
+DEFAULT_FFN_RESIDUAL_DROPOUT_RATE = 0.1
+DEFAULT_INPUT_DROPOUT_RATE = 0.1
+
@dataclass
class ConformerConfig:
@@ -26,13 +31,7 @@ class ConformerConfig:
num_attention_heads: int = 8
num_encoder_layers: int = 4
attention_dropout_rate: float = 0.0
- # If None, defaults to 0.1.
- attention_residual_dropout_rate: Optional[float] = 0.1
- # If None, defaults to 0.0.
- conv_residual_dropout_rate: Optional[float] = 0.0
feed_forward_dropout_rate: float = 0.0
- # If None, defaults to 0.1.
- feed_forward_residual_dropout_rate: Optional[float] = 0.1
convolution_kernel_size: int = 5
feed_forward_expansion_factor: int = 4
freq_mask_count: int = 2
@@ -42,8 +41,6 @@ class ConformerConfig:
time_mask_max_ratio: float = 0.05
time_masks_per_frame: float = 0.0
use_dynamic_time_mask_max_frames: bool = True
- # If None, defaults to 0.1.
- input_dropout_rate: Optional[float] = 0.1
batch_norm_momentum: float = 1 - 0.999
batch_norm_epsilon: float = 0.001
use_specaug: bool = True
@@ -81,11 +78,9 @@ class Subsample(nn.Module):
def __init__(self,
encoder_dim: int = 0,
- input_dropout_rate: float = 0.0,
num_bins: int = 80):
super().__init__()
self.encoder_dim = encoder_dim
- self.input_dropout_rate = input_dropout_rate
self.conv1 = Conv2dSubsampling(
input_channels=1, output_channels=encoder_dim)
@@ -100,7 +95,7 @@ def __init__(self,
def forward(self, inputs, input_paddings, dropout_rate=None):
if dropout_rate is None:
- dropout_rate = self.input_dropout_rate
+ dropout_rate = DEFAULT_INPUT_DROPOUT_RATE
output_paddings = input_paddings
outputs = inputs[:, None, :, :]
@@ -207,14 +202,9 @@ def __init__(self, config: ConformerConfig):
out_features=config.encoder_dim,
bias=True)
- if config.feed_forward_residual_dropout_rate is None:
- self.feed_forward_residual_dropout_rate = 0.1
- else:
- self.feed_forward_residual_dropout_rate = config.feed_forward_residual_dropout_rate
-
def forward(self, inputs, padding_mask, dropout_rate=None):
if dropout_rate is None:
- dropout_rate = self.feed_forward_residual_dropout_rate
+ dropout_rate = DEFAULT_FFN_RESIDUAL_DROPOUT_RATE
inputs = self.ln(inputs)
inputs = self.linear1(inputs)
@@ -319,14 +309,10 @@ def __init__(self, config: ConformerConfig):
self.ln = LayerNorm(dim=config.encoder_dim)
self.self_attention = MHSAwithQS(config)
- if config.attention_residual_dropout_rate is None:
- self.attention_residual_dropout_rate = 0.1
- else:
- self.attention_residual_dropout_rate = config.attention_residual_dropout_rate
def forward(self, outputs, paddings, dropout_rate=None):
if dropout_rate is None:
- dropout_rate = self.attention_residual_dropout_rate
+ dropout_rate = DEFAULT_ATTN_RESIDUAL_DROPOUT_RATE
outputs = self.ln(outputs)
outputs = self.self_attention(
@@ -413,14 +399,10 @@ def __init__(self, config):
groups=config.encoder_dim)
self.bn = BatchNorm(config)
self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim)
- if config.conv_residual_dropout_rate is None:
- self.conv_residual_dropout_rate = 0.0
- else:
- self.conv_residual_dropout_rate = config.conv_residual_dropout_rate
def forward(self, inputs, input_paddings, dropout_rate=None):
if dropout_rate is None:
- dropout_rate = self.conv_residual_dropout_rate
+ dropout_rate = DEFAULT_CONV_RESIDUAL_DROPOUT_RATE
inputs = self.ln(inputs)
@@ -490,13 +472,8 @@ def __init__(self, config: ConformerConfig):
time_masks_per_frame=config.time_masks_per_frame,
use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames
)
- if config.input_dropout_rate is None:
- input_dropout_rate = 0.1
- else:
- input_dropout_rate = config.input_dropout_rate
self.subsample = Subsample(
encoder_dim=config.encoder_dim,
- input_dropout_rate=input_dropout_rate,
num_bins=preprocessing_config.num_bins)
self.conformers = nn.ModuleList(
[ConformerBlock(config) for _ in range(config.num_encoder_layers)])
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
index 39f33f4aa..2d0942fe9 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
@@ -63,8 +63,7 @@ def attention_temperature(self) -> float:
def init_model_fn(
self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ rng: spec.RandomState) -> spec.ModelInitState:
"""Conformer model init function."""
torch.random.manual_seed(rng[0])
# Configure torch backends to avoid OOM errors.
@@ -78,10 +77,6 @@ def init_model_fn(
activation_function_name = 'swish'
model = models.ConformerEncoderDecoder(
models.ConformerConfig(
- attention_residual_dropout_rate=dropout_rate,
- feed_forward_residual_dropout_rate=dropout_rate,
- conv_residual_dropout_rate=dropout_rate,
- input_dropout_rate=dropout_rate,
use_specaug=self.use_specaug,
attention_temperature=self.attention_temperature,
use_post_layer_norm=self.use_post_layer_norm,
diff --git a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
index 19525a98b..ec8318c9a 100644
--- a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
@@ -16,14 +16,15 @@
import os
from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import (
- # ConformerConfig,
- ConformerEncoderDecoder as OriginalConf
+ ConformerConfig as OriginalConfig,
+ ConformerEncoderDecoder as OriginalModel
)
from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import(
- ConformerEncoderDecoder as CustomConf,
- ConformerConfig,
+ ConformerConfig as CustomConfig,
+ ConformerEncoderDecoder as CustomModel,
)
+N_LAYERS = 3
B, T = 32, 36_000
DEVICE = 'cuda'
@@ -37,55 +38,76 @@
class ConformerEquivalenceTest(parameterized.TestCase):
@parameterized.named_parameters(
- dict(testcase_name='p=None', dropout_rate=None),
dict(testcase_name='p=0.0', dropout_rate=0.0),
dict(testcase_name='p=0.2', dropout_rate=0.2),
dict(testcase_name='p=0.7', dropout_rate=0.7),
dict(testcase_name='p=1.0', dropout_rate=1.0),
)
def test_forward(self, dropout_rate):
-
- # Test initalizing custom model with a None dropout_rate
- for custom_init_dropout_rate in [None, dropout_rate]:
+
+ torch.manual_seed(SEED)
+ orig = OriginalModel(
+ OriginalConfig(
+ num_encoder_layers=N_LAYERS,
+ attention_residual_dropout_rate=dropout_rate,
+ conv_residual_dropout_rate=dropout_rate,
+ feed_forward_residual_dropout_rate=dropout_rate,
+ input_dropout_rate=dropout_rate,
+ )).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustomModel(
+ CustomConfig(
+ num_encoder_layers=N_LAYERS
+ )
+ ).to(DEVICE)
+
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ x = torch.randn(B, T, device=DEVICE)
+ paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
+
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
torch.manual_seed(SEED)
- orig = OriginalConf(
- ConformerConfig(
- num_encoder_layers=3,
- attention_residual_dropout_rate=dropout_rate,
- conv_residual_dropout_rate=dropout_rate,
- feed_forward_residual_dropout_rate=dropout_rate,
- input_dropout_rate=dropout_rate,
- )).to(DEVICE)
-
+ y1, p1 = orig(x, paddings)
torch.manual_seed(SEED)
- cust = CustomConf(
- ConformerConfig(
- num_encoder_layers=3,
- attention_residual_dropout_rate=custom_init_dropout_rate,
- conv_residual_dropout_rate=custom_init_dropout_rate,
- feed_forward_residual_dropout_rate=custom_init_dropout_rate,
- input_dropout_rate=custom_init_dropout_rate
- )).to(DEVICE)
-
- orig.load_state_dict(cust.state_dict()) # sync weights
-
- x = torch.randn(B, T, device=DEVICE)
- paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
-
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
-
- torch.manual_seed(SEED)
- y1, p1 = orig(x, paddings)
-
- torch.manual_seed(SEED)
+ if mode == 'train':
y2, p2 = cust(x, paddings, dropout_rate=dropout_rate)
-
- assert_close(y1, y2, atol=0, rtol=0)
- assert_close(p1, p2, atol=0, rtol=0)
+ else:
+ y2, p2 = cust(x, paddings)
+
+ assert_close(y1, y2, atol=0, rtol=0)
+ assert_close(p1, p2, atol=0, rtol=0)
+
+ @parameterized.named_parameters(
+ dict(testcase_name=''),
+ )
+ def test_default_dropout(self):
+ """Test default dropout_rate."""
+
+ torch.manual_seed(SEED)
+ orig = OriginalModel(OriginalConfig(num_encoder_layers=N_LAYERS)).to(DEVICE)
+ torch.manual_seed(SEED)
+ cust = CustomModel(CustomConfig(num_encoder_layers=N_LAYERS)).to(DEVICE)
+ orig.load_state_dict(cust.state_dict())
+
+ x = torch.randn(B, T, device=DEVICE)
+ paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
+ torch.manual_seed(SEED)
+ y1, p1 = orig(x, paddings)
+ torch.manual_seed(SEED)
+ y2, p2 = cust(x, paddings)
+
+ assert_close(y1, y2, atol=0, rtol=0)
+ assert_close(p1, p2, atol=0, rtol=0)
if __name__ == '__main__':
absltest.main()
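The Conformer patch keeps dropout_rate=None in the sub-block signatures and swaps the old per-config fields for named module-level defaults, so an omitted rate falls back per sub-block rather than per config object. The fallback in isolation (illustrative feed-forward residual, not the repo class):

import torch
import torch.nn.functional as F
from torch import nn

DEFAULT_FFN_RESIDUAL_DROPOUT_RATE = 0.1  # as added to models_dropout.py

class FeedForwardResidualSketch(nn.Module):

  def __init__(self, dim: int):
    super().__init__()
    self.ln = nn.LayerNorm(dim)
    self.linear = nn.Linear(dim, dim)

  def forward(self, x: torch.Tensor, dropout_rate=None) -> torch.Tensor:
    if dropout_rate is None:
      dropout_rate = DEFAULT_FFN_RESIDUAL_DROPOUT_RATE  # named default, not config state
    y = self.linear(self.ln(x))
    y = F.dropout(y, dropout_rate, training=self.training)
    return x + y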
From 31d64f6c3228eb410065b212994d3f02883d19ee Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Thu, 12 Jun 2025 13:36:15 +0200
Subject: [PATCH 051/123] rm dropout_rate at init from all workloads
---
.../fastmri/fastmri_pytorch/models_dropout.py | 5 +-
.../imagenet_pytorch/models_dropout.py | 5 +-
.../librispeech_pytorch/models_dropout.py | 24 +----
.../librispeech_pytorch/workload.py | 5 +-
.../ogbg/ogbg_pytorch/models_dropout.py | 15 +--
.../workloads/ogbg/ogbg_pytorch/workload.py | 4 +-
.../wmt/wmt_pytorch/models_dropout.py | 44 ++++----
.../workloads/wmt/wmt_pytorch/workload.py | 5 +-
.../test_model_equivalence.py | 17 ++-
.../fastmri_pytorch/test_model_equivalence.py | 15 ++-
.../test_model_equivalence.py | 15 ++-
.../test_model_equivalence.py | 12 ++-
.../test_model_equivalence.py | 100 +++++++++++-------
.../ogbg_pytorch/test_model_equivalence.py | 55 +++++++---
.../wmt_pytorch/test_model_equivalence.py | 97 +++++++++++------
15 files changed, 234 insertions(+), 184 deletions(-)
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
index 73b1d81d9..0e59e1436 100644
--- a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
@@ -74,7 +74,10 @@ def __init__(self,
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
init_utils.pytorch_default_init(m)
- def forward(self, x: Tensor, dropout_rate: Optional[float] = DEFAULT_DROPOUT_RATE) -> Tensor:
+ def forward(
+ self,
+ x: Tensor,
+ dropout_rate: float = DEFAULT_DROPOUT_RATE) -> Tensor:
stack = []
output = x
diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py
index 8641847b0..570cee575 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py
@@ -345,7 +345,10 @@ def reset_parameters(self) -> None:
def get_posemb(self, x: spec.Tensor) -> spec.Tensor:
return posemb_sincos_2d(x).type(self.dtype)
- def forward(self, x: spec.Tensor, dropout_rate=DEFAULT_DROPOUT_RATE) -> spec.Tensor:
+ def forward(
+ self,
+ x: spec.Tensor,
+ dropout_rate: float = DEFAULT_DROPOUT_RATE) -> spec.Tensor:
# Patch extraction.
x = self.conv_patch_extract(x)
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py
index 8797aa578..21a4df614 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py
@@ -17,6 +17,7 @@
SpecAug
USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
+DEFAULT_DROPOUT_RATE = 0.1
@dataclass
@@ -38,10 +39,6 @@ class DeepspeechConfig:
use_dynamic_time_mask_max_frames: bool = True
batch_norm_momentum: float = 1 - 0.999
batch_norm_epsilon: float = 0.001
- # If None, defaults to 0.1.
- input_dropout_rate: Optional[float] = 0.1
- # If None, defaults to 0.1.
- feed_forward_dropout_rate: Optional[float] = 0.1
enable_residual_connections: bool = True
enable_decoder_layer_norm: bool = True
bidirectional: bool = True
@@ -87,14 +84,7 @@ def __init__(self, config: DeepspeechConfig):
self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True)
- if config.input_dropout_rate is None:
- self.input_dropout_rate = 0.1
- else:
- self.input_dropout_rate = config.input_dropout_rate
-
- def forward(self, inputs, input_paddings, dropout_rate=None):
- if dropout_rate is None:
- dropout_rate = self.input_dropout_rate
+ def forward(self, inputs, input_paddings, dropout_rate):
output_paddings = input_paddings
outputs = inputs[:, None, :, :]
@@ -207,14 +197,8 @@ def __init__(self, config: DeepspeechConfig):
batch_norm_momentum=config.batch_norm_momentum,
batch_norm_epsilon=config.batch_norm_epsilon)
self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True)
- if config.feed_forward_dropout_rate is None:
- self.feed_forward_dropout_rate = 0.1
- else:
- self.feed_forward_dropout_rate = config.feed_forward_dropout_rate
- def forward(self, inputs, input_paddings, dropout_rate=None):
- if dropout_rate is None:
- dropout_rate = self.feed_forward_dropout_rate
+ def forward(self, inputs, input_paddings, dropout_rate):
padding_mask = (1 - input_paddings)[:, :, None]
if self.config.layernorm_everywhere:
@@ -367,7 +351,7 @@ def __init__(self, config: DeepspeechConfig):
self.lin = nn.Linear(config.encoder_dim, config.vocab_size)
- def forward(self, inputs, input_paddings, dropout_rate=None):
+ def forward(self, inputs, input_paddings, dropout_rate=DEFAULT_DROPOUT_RATE):
outputs = inputs
output_paddings = input_paddings
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
index 932ba9392..e6ec4764f 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
@@ -24,15 +24,12 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload):
def init_model_fn(
self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ rng: spec.RandomState) -> spec.ModelInitState:
"""Deepspeech model init function."""
torch.random.manual_seed(rng[0])
model = DeepspeechEncoderDecoder(
DeepspeechConfig(
- feed_forward_dropout_rate=dropout_rate,
use_specaug=self.use_specaug,
- input_dropout_rate=dropout_rate,
use_tanh=self.use_tanh,
enable_residual_connections=self.enable_residual_connections,
enable_decoder_layer_norm=self.enable_decoder_layer_norm,
diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py
index b86b88caa..c8ed23dda 100644
--- a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py
+++ b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py
@@ -11,6 +11,8 @@
from algoperf import init_utils
from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout
+DEFAULT_DROPOUT_RATE = 0.1
+
def _make_mlp(in_dim, hidden_dims, activation_fn):
"""Creates a MLP with specified dimensions."""
@@ -34,7 +36,6 @@ class GNN(nn.Module):
def __init__(self,
num_outputs: int = 128,
- dropout_rate: Optional[float] = 0.1,
activation_fn_name: str = 'relu',
latent_dim: int = 256,
hidden_dims: Tuple[int] = (256,),
@@ -44,8 +45,6 @@ def __init__(self,
self.hidden_dims = hidden_dims
self.num_message_passing_steps = num_message_passing_steps
self.num_outputs = num_outputs
- if dropout_rate is None:
- self.dropout_rate = 0.1
# in_features are specifically chosen for the ogbg workload.
self.node_embedder = nn.Linear(in_features=9, out_features=self.latent_dim)
self.edge_embedder = nn.Linear(in_features=3, out_features=self.latent_dim)
@@ -94,9 +93,10 @@ def __init__(self,
if isinstance(m, nn.Linear):
init_utils.pytorch_default_init(m)
- def forward(self, graph: GraphsTuple, dropout_rate=None) -> torch.Tensor:
- if dropout_rate is None:
- dropout_rate = self.dropout_rate
+ def forward(
+ self,
+ graph: GraphsTuple,
+ dropout_rate: float = DEFAULT_DROPOUT_RATE) -> torch.Tensor:
graph = graph._replace(
globals=torch.zeros([graph.n_node.shape[0], self.num_outputs],
@@ -148,7 +148,7 @@ def __init__(self,
self.update_global_fn = update_global_fn
self._supports_custom_dropout = True # supports SequentialWithDropout
- def forward(self, graph: GraphsTuple, dropout_rate=None) -> GraphsTuple:
+ def forward(self, graph: GraphsTuple, dropout_rate: float) -> GraphsTuple:
"""Applies a configured GraphNetwork to a graph.
This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261
There is one difference. For the nodes update the class aggregates over the
@@ -161,6 +161,7 @@ def forward(self, graph: GraphsTuple, dropout_rate=None) -> GraphsTuple:
GraphNets, for more information please see the paper.
Args:
graph: a `GraphsTuple` containing the graph.
+ dropout_rate: dropout probability value.
Returns:
Updated `GraphsTuple`.
"""
diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
index 1dd85951d..7ead696ce 100644
--- a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
+++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
@@ -138,12 +138,10 @@ def _build_input_queue(self,
def init_model_fn(
self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
model = GNN(
num_outputs=self._num_outputs,
- dropout_rate=dropout_rate,
hidden_dims=self.hidden_dims,
latent_dim=self.latent_dim,
num_message_passing_steps=self.num_message_passing_steps,
diff --git a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py
index 6e265cd7f..a5d822669 100644
--- a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py
+++ b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py
@@ -9,6 +9,8 @@
from torch.nn.init import normal_
from torch.nn.init import xavier_uniform_
+DEFAULT_DROPOUT_RATE = 0.1
+
def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor:
"""Make a causal mask for self-attention.
@@ -104,17 +106,12 @@ def __init__(self,
nhead: int = 16,
d_hid: int = 1024,
nlayers: int = 6,
- dropout_rate: Optional[float] = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
glu: bool = False,
layer_norm_eps: float = 1e-6,
attention_temp: float = 1.0,
pre_ln: bool = True):
super().__init__()
- if dropout_rate is None:
- self.dropout_rate = 0.1
- else:
- self.dropout_rate = dropout_rate
self.pos_encoder = PositionalEncoding(d_model)
self.shared_embedding = nn.Embedding(ntoken, d_model)
self.encoder = Encoder(d_model,
@@ -159,7 +156,7 @@ def forward(self,
inputs_segmentation: Optional[Tensor] = None,
targets_segmentation: Optional[Tensor] = None,
decode: bool = False,
- dropout_rate: Optional[float] = None) -> Tensor:
+ dropout_rate: float = DEFAULT_DROPOUT_RATE) -> Tensor:
"""
Args:
src: Tensor, shape [batch_size, seq_len]
@@ -169,7 +166,7 @@ def forward(self,
inputs_segmentation: Optional[Tensor], shape [batch_size, seq_len]
targets_segmentation: Optional[Tensor], shape [batch_size, seq_len]
decode: bool
- dropout_rate: Optional[float]
+ dropout_rate: float
Returns:
output Tensor of shape [batch_size, seq_len, ntoken]
@@ -177,9 +174,6 @@ def forward(self,
if src.size(0) != tgt.size(0):
raise RuntimeError('The batch size of src and tgt must be equal.')
- if dropout_rate is None:
- dropout_rate = self.dropout_rate
-
memory = self.encoder(
src,
inputs_positions=inputs_positions,
@@ -234,13 +228,13 @@ def __init__(self,
def forward(self, src: Tensor,
mask: Optional[Tensor] = None,
- dropout_rate: Optional[float] = None) -> Tensor:
+ dropout_rate: Optional[float] = 0.0) -> Tensor:
"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
mask: the mask for the src sequence (optional).
- dropout_rate: the dropout probability (optional)
+ dropout_rate: the dropout probability (optional).
Shape:
see the docs in Transformer class.
@@ -293,7 +287,7 @@ def forward(self,
src: Tensor,
inputs_positions: Optional[Tensor] = None,
inputs_segmentation: Optional[Tensor] = None,
- dropout_rate: Optional[float] = None) -> Tensor:
+ dropout_rate: Optional[float] = 0.0) -> Tensor:
src = src.to(torch.int)
src_mask = make_src_mask(src, inputs_segmentation, self.nhead)
src = self.shared_embedding(src)
@@ -339,7 +333,7 @@ def forward(
decode: bool = False,
max_len: Optional[int] = None,
cache: Optional[dict] = None,
- dropout_rate: Optional[float] = None) -> Any:
+ dropout_rate: Optional[float] = 0.0) -> Any:
tgt = tgt.to(torch.int)
tgt_mask, memory_mask = make_tgt_and_memory_mask(
tgt, src, inputs_segmentation, targets_segmentation,
@@ -389,7 +383,7 @@ def forward(
inputs_positions: Optional[Tensor] = None,
decode: bool = False,
cache: Optional[Dict[str, Dict[str, Tensor]]] = None,
- dropout_rate: Optional[float] = None
+ dropout_rate: Optional[float] = 0.0
) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]:
"""
Args:
@@ -397,6 +391,7 @@ def forward(
inputs_positions: Tensor (shape [batch_size, seq_len]) or None
decode: bool
cache: Dict[str, Dict[str, Tensor]] or None
+ dropout_rate: Optional[float]
Returns:
Tensor or Tuple[Tensor, Dict[str, Dict[str, Tensor]]]
"""
@@ -438,7 +433,6 @@ class TransformerEncoderLayer(nn.Module):
nhead: the number of heads in the multiheadattention models (default=16).
dim_feedforward: the dimension of the feedforward network model
(default=1024).
- dropout_rate: the dropout_rate value (default=0.1).
activation: the activation function of the intermediate layer, can be a
string ("relu" or "gelu") or a unary callable (default=F.relu).
layer_norm_eps: the eps value in layer normalization components
@@ -490,7 +484,7 @@ def __init__(self,
def forward(self,
src: Tensor,
src_mask: Optional[Tensor] = None,
- dropout_rate: Optional[float] = None) -> Tensor:
+ dropout_rate: Optional[float] = 0.0) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
@@ -514,14 +508,14 @@ def forward(self,
def _sa_block(self,
x: Tensor,
attn_mask: Optional[Tensor],
- dropout_rate: Optional[float] = None) -> Tensor:
+ dropout_rate: Optional[float] = 0.0) -> Tensor:
x, _ = self.self_attn(x, attn_mask=attn_mask, dropout_rate=dropout_rate)
return F.dropout(x, dropout_rate, training=self.training)
# Feed forward block:
def _ff_block(self,
inputs: Tensor,
- dropout_rate: Optional[float] = None) -> Tensor:
+ dropout_rate: Optional[float] = 0.0) -> Tensor:
x = self.activation(self.linear1(inputs))
if self.glu:
y = self.linear_glu(inputs)
@@ -585,7 +579,7 @@ def forward(self,
decode: bool = False,
max_len: Optional[int] = None,
cache: Optional[dict] = None,
- dropout_rate: Optional[float] = None) -> Any:
+ dropout_rate: Optional[float] = 0.0) -> Any:
r"""Pass the inputs (and mask) through the decoder layer in turn.
Args:
tgt: the sequence to the decoder (required).
@@ -705,7 +699,7 @@ def forward( # pylint: disable=arguments-renamed
max_len: Optional[int] = None,
cache: Optional[dict] = None,
index: Optional[int] = None,
- dropout_rate: Optional[float] = None) -> Any:
+ dropout_rate: Optional[float] = 0.0) -> Any:
r"""Pass the inputs (and mask) through the decoder layer.
Args:
tgt: the sequence to the decoder layer (required).
@@ -757,7 +751,7 @@ def _sa_block( # pylint: disable=arguments-renamed
max_len: Optional[int] = None,
cache: Optional[dict] = None,
index: Optional[int] = None,
- dropout_rate: Optional[float] = None) -> Any:
+ dropout_rate: Optional[float] = 0.0) -> Any:
x, cache = self.self_attn(
x,
attn_mask=attn_mask,
@@ -771,7 +765,7 @@ def _sa_block( # pylint: disable=arguments-renamed
# Multihead attention block:
def _mha_block(self, x: Tensor, mem: Tensor,
attn_mask: Optional[Tensor],
- dropout_rate: Optional[float] = None) -> Tensor:
+ dropout_rate: Optional[float] = 0.0) -> Tensor:
x, _ = self.multihead_attn(
x,
mem,
@@ -781,7 +775,7 @@ def _mha_block(self, x: Tensor, mem: Tensor,
# Feed forward block.
def _ff_block(self, inputs: Tensor,
- dropout_rate: Optional[float] = None) -> Tensor:
+ dropout_rate: Optional[float] = 0.0) -> Tensor:
x = self.activation(self.linear1(inputs))
if self.glu:
y = self.linear_glu(inputs)
@@ -861,7 +855,7 @@ def forward(self,
max_len: Optional[int] = None,
cache: Optional[dict] = None,
index: Optional[int] = None,
- dropout_rate: Optional[float] = None) -> Any: # TODO: (nico) remove default?!
+ dropout_rate: Optional[float] = 0.0) -> Any: # TODO: (nico) remove default?!
r"""
Args:
x: Batch of input sequences of shape
diff --git a/algoperf/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py
index 64eea73b7..bb9c3834f 100644
--- a/algoperf/workloads/wmt/wmt_pytorch/workload.py
+++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py
@@ -167,8 +167,7 @@ def translate_and_calculate_bleu(self,
def init_model_fn(
self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
if self.activation == 'relu':
@@ -179,8 +178,6 @@ def init_model_fn(
raise ValueError(f'Unknown activation function {self.activation}.')
model = Transformer(
- dropout_rate=dropout_rate,
- attention_dropout_rate=dropout_rate,
pre_ln=self.pre_ln,
attention_temp=self.attention_temp,
activation=activation,
diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
index c59331a3d..db56b17cf 100644
--- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
@@ -56,18 +56,13 @@ def test_forward(self, model, dropout_rate):
x = torch.randn(BATCH, FEATURES, device=DEVICE)
for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
-
- torch.manual_seed(SEED)
- y1 = orig(x)
- torch.manual_seed(SEED)
- if mode == 'train':
- y2 = cust(x, dropout_rate)
- else:
- y2 = cust(x)
-
+ getattr(orig, mode)(); getattr(cust, mode)()
+ torch.manual_seed(SEED); y1 = orig(x)
+ torch.manual_seed(SEED); y2 = cust(x, dropout_rate)
assert_close(y1, y2, atol=0, rtol=0)
+ if mode == 'eval': # one extra test: omit dropout at eval
+ torch.manual_seed(SEED); y2 = cust(x)
+ assert_close(y1, y2, atol=0, rtol=0)
@parameterized.named_parameters(
dict(testcase_name='DLRMResNet, default', model='dlrm_resnet'),
diff --git a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
index 6c8ca896c..0d3d52980 100644
--- a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
@@ -28,16 +28,13 @@ class FastMRIModeEquivalenceTest(parameterized.TestCase):
def fwd_pass(self, orig, cust, dropout_rate):
x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
- torch.manual_seed(0)
- y1 = orig(x)
- torch.manual_seed(0)
- if mode == 'train':
- y2 = cust(x, dropout_rate)
- else:
- y2 = cust(x)
+ getattr(orig, mode)(); getattr(cust, mode)()
+ torch.manual_seed(0); y1 = orig(x)
+ torch.manual_seed(0); y2 = cust(x, dropout_rate)
assert_close(y1, y2, atol=0, rtol=0)
+ if mode == 'eval': # one extra test: omit dropout at eval
+ torch.manual_seed(0); y2 = cust(x)
+ assert_close(y1, y2, atol=0, rtol=0)
@parameterized.named_parameters(
dict(testcase_name='p=0.0', dropout_rate=0.0),
diff --git a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
index 32db2e7d4..d19fad0ba 100644
--- a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
@@ -30,16 +30,13 @@ class ImageNetVitModeEquivalenceTest(parameterized.TestCase):
def fwd_pass(self, orig, cust, dropout_rate):
x = torch.randn(BATCH, C, H, W, device=DEVICE)
for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
- torch.manual_seed(0)
- y1 = orig(x)
- torch.manual_seed(0)
- if mode == 'train':
- y2 = cust(x, dropout_rate)
- else:
- y2 = cust(x)
+ getattr(orig, mode)(); getattr(cust, mode)()
+ torch.manual_seed(0); y1 = orig(x)
+ torch.manual_seed(0); y2 = cust(x, dropout_rate)
assert_close(y1, y2, atol=0, rtol=0)
+ if mode == 'eval': # one extra test: omit dropout at eval
+ torch.manual_seed(0); y2 = cust(x, dropout_rate)
+ assert_close(y1, y2, atol=0, rtol=0)
@parameterized.named_parameters(
dict(testcase_name='p=0.0', dropout_rate=0.0),
diff --git a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
index ec8318c9a..a4238bbc9 100644
--- a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
@@ -74,14 +74,16 @@ def test_forward(self, dropout_rate):
torch.manual_seed(SEED)
y1, p1 = orig(x, paddings)
torch.manual_seed(SEED)
- if mode == 'train':
- y2, p2 = cust(x, paddings, dropout_rate=dropout_rate)
- else:
- y2, p2 = cust(x, paddings)
-
+ y2, p2 = cust(x, paddings, dropout_rate=dropout_rate)
+
assert_close(y1, y2, atol=0, rtol=0)
assert_close(p1, p2, atol=0, rtol=0)
+ if mode == 'eval': # one extra test: omit dropout at eval
+ torch.manual_seed(SEED)
+ y2, p2 = cust(x, paddings)
+ assert_close(y1, y2, atol=0, rtol=0)
+ assert_close(p1, p2, atol=0, rtol=0)
@parameterized.named_parameters(
dict(testcase_name=''),
diff --git a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
index e31f4a7eb..acdc8c5b3 100644
--- a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
@@ -14,11 +14,12 @@
import os
from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import (
- DeepspeechEncoderDecoder as OriginalModel
+ DeepspeechEncoderDecoder as OriginalModel,
+ DeepspeechConfig as OriginalConfig
)
from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import(
DeepspeechEncoderDecoder as CustomModel,
- DeepspeechConfig,
+ DeepspeechConfig as CustomConfig
)
B, T = 32, 30_000
@@ -35,55 +36,82 @@
class DeepSpeechEquivalenceTest(parameterized.TestCase):
@parameterized.named_parameters(
- dict(testcase_name='p=None', dropout_rate=None),
dict(testcase_name='p=0.0', dropout_rate=0.0),
dict(testcase_name='p=0.2', dropout_rate=0.2),
dict(testcase_name='p=0.7', dropout_rate=0.7),
dict(testcase_name='p=1.0', dropout_rate=1.0),
)
def test_forward(self, dropout_rate):
-
- # Test initalizing custom model with a None dropout_rate
- for custom_init_dropout_rate in [None, dropout_rate]:
+ """Test different dropout_rate values."""
+ torch.manual_seed(SEED)
+ orig = OriginalModel(
+ OriginalConfig(
+ num_lstm_layers=2,
+ num_ffn_layers=2,
+ input_dropout_rate=dropout_rate,
+ feed_forward_dropout_rate=dropout_rate,
+ )).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustomModel(CustomConfig(
+ num_lstm_layers=2,
+ num_ffn_layers=2,
+ )).to(DEVICE)
+
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ if TORCH_COMPILE:
+ orig = torch.compile(orig); cust = torch.compile(cust)
+
+ x = torch.randn(B, T, device=DEVICE)
+ paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
+
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
torch.manual_seed(SEED)
- orig = OriginalModel(
- DeepspeechConfig(
- num_lstm_layers=2,
- num_ffn_layers=2,
- input_dropout_rate=dropout_rate,
- feed_forward_dropout_rate=dropout_rate,
- )).to(DEVICE)
+ y1, p1 = orig(x, paddings)
torch.manual_seed(SEED)
- cust = CustomModel(DeepspeechConfig(
- num_lstm_layers=2,
- num_ffn_layers=2,
- input_dropout_rate=custom_init_dropout_rate,
- feed_forward_dropout_rate=custom_init_dropout_rate,
- )).to(DEVICE)
+ y2, p2 = cust(x, paddings, dropout_rate=dropout_rate)
- orig.load_state_dict(cust.state_dict()) # sync weights
-
- if TORCH_COMPILE:
- orig = torch.compile(orig); cust = torch.compile(cust)
-
- x = torch.randn(B, T, device=DEVICE)
- paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
-
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
-
- torch.manual_seed(SEED)
- y1, p1 = orig(x, paddings)
-
+ assert_close(y1, y2, atol=0, rtol=0)
+ assert_close(p1, p2, atol=0, rtol=0)
+
+ if mode == 'eval': # one extra test: omit dropout at eval
torch.manual_seed(SEED)
- y2, p2 = cust(x, paddings, dropout_rate=dropout_rate)
-
+ y2, p2 = cust(x, paddings)
assert_close(y1, y2, atol=0, rtol=0)
assert_close(p1, p2, atol=0, rtol=0)
+ @parameterized.named_parameters(
+ dict(testcase_name=''),
+ )
+ def test_default_dropout(self):
+ """Test default dropout_rate."""
+
+ torch.manual_seed(SEED)
+    orig = OriginalModel(OriginalConfig(num_lstm_layers=2, num_ffn_layers=2)).to(DEVICE)
+    torch.manual_seed(SEED)
+    cust = CustomModel(CustomConfig(num_lstm_layers=2, num_ffn_layers=2)).to(DEVICE)
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ if TORCH_COMPILE:
+ orig = torch.compile(orig); cust = torch.compile(cust)
+
+ x = torch.randn(B, T, device=DEVICE)
+ paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
+
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)(); getattr(cust, mode)()
+ torch.manual_seed(SEED); y1, p1 = orig(x, paddings)
+ torch.manual_seed(SEED); y2, p2 = cust(x, paddings)
+ assert_close(y1, y2, atol=0, rtol=0)
+ assert_close(p1, p2, atol=0, rtol=0)
+
+
if __name__ == '__main__':
absltest.main()
diff --git a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
index cc1857705..aaca6cebd 100644
--- a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
@@ -42,35 +42,62 @@ def _rand_graph():
class GNNEquivalenceTest(parameterized.TestCase):
@parameterized.named_parameters(
- dict(testcase_name='None', dropout_rate=None),
dict(testcase_name='0.0', dropout_rate=0.0),
dict(testcase_name='0.2', dropout_rate=0.2),
dict(testcase_name='0.7', dropout_rate=0.7),
dict(testcase_name='1.0', dropout_rate=1.0),
)
def test_forward(self, dropout_rate):
+ """Test different dropout_rates."""
- # Test initalizing custom model with a None dropout_rate
- for custom_init_dropout_rate in [None, dropout_rate]:
+ orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE)
+ cust = CustomModel().to(DEVICE)
+ orig.load_state_dict(cust.state_dict()) # sync weights
- orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE)
- cust = CustomModel(dropout_rate=custom_init_dropout_rate).to(DEVICE)
- orig.load_state_dict(cust.state_dict()) # sync weights
+ graph = _rand_graph()
- graph = _rand_graph()
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
+ torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
+ y1 = orig(graph)
- torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
- y1 = orig(graph)
+ torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
+ y2 = cust(graph, dropout_rate=dropout_rate)
- torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
- y2 = cust(graph, dropout_rate=dropout_rate)
+ assert_close(y1, y2, atol=0, rtol=0)
+ if mode == 'eval': # one extra test: omit dropout at eval
+ torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
+ y2 = cust(graph)
assert_close(y1, y2, atol=0, rtol=0)
+ @parameterized.named_parameters(
+ dict(testcase_name=''),
+ )
+ def test_default_dropout(self):
+ """Test default dropout_rate."""
+
+ orig = OriginalModel().to(DEVICE)
+ cust = CustomModel().to(DEVICE)
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ graph = _rand_graph()
+
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
+ torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
+ y1 = orig(graph)
+
+ torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
+ y2 = cust(graph)
+
+ assert_close(y1, y2, atol=0, rtol=0)
+
+
if __name__ == '__main__':
absltest.main()
diff --git a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
index 9aca717d9..9675f1df2 100644
--- a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
@@ -32,52 +32,79 @@ def _rand_tokens(bs, seqlen):
class TransformerEquivalenceTest(parameterized.TestCase):
@parameterized.named_parameters(
- # NOTE: removed dropout=1.0 will generate nan in scaled_dot_product_attention
-
- dict(testcase_name="None", dropout_rate=None, compile=False),
+    # NOTE: dropout=1.0 removed since it generates NaNs in scaled_dot_product_attention
dict(testcase_name="0.0", dropout_rate=0.0, compile=False),
dict(testcase_name="0.2", dropout_rate=0.2, compile=False),
dict(testcase_name="0.7", dropout_rate=0.7, compile=False),
-
- dict(testcase_name="p=None, compile", dropout_rate=None, compile=True),
- dict(testcase_name="p=0.0, compile", dropout_rate=0.0, compile=True),
- dict(testcase_name="p=0.2, compile", dropout_rate=0.2, compile=True),
- dict(testcase_name="p=0.7, compile", dropout_rate=0.7, compile=True),
+ dict(testcase_name="p=0.0_compile", dropout_rate=0.0, compile=True),
+ dict(testcase_name="p=0.2_compile", dropout_rate=0.2, compile=True),
+ dict(testcase_name="p=0.7_compile", dropout_rate=0.7, compile=True),
)
- def test_forward(self, dropout_rate, compile):
-
- # Test initalizing custom model with a None dropout_rate
- for custom_init_dropout_rate in [None, dropout_rate]:
-
- orig = OriginalModel(
- dropout_rate=dropout_rate,
- attention_dropout_rate=dropout_rate
- ).to(DEVICE)
- cust = CustomModel(
- dropout_rate=custom_init_dropout_rate
- ).to(DEVICE)
-
- orig.load_state_dict(cust.state_dict()) # sync weights
-
- if compile:
- orig = torch.compile(orig)
- cust = torch.compile(cust)
+ def test_dropout_value(self, dropout_rate, compile):
- src = _rand_tokens(B, SRC_LEN)
- tgt = _rand_tokens(B, TGT_LEN)
+ orig = OriginalModel(
+ dropout_rate=dropout_rate,
+ attention_dropout_rate=dropout_rate
+ ).to(DEVICE)
+ cust = CustomModel().to(DEVICE)
+
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ if compile:
+ orig = torch.compile(orig)
+ cust = torch.compile(cust)
- for mode in ("train", "eval"):
- getattr(orig, mode)()
- getattr(cust, mode)()
+ src = _rand_tokens(B, SRC_LEN)
+ tgt = _rand_tokens(B, TGT_LEN)
- torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
- y1 = orig(src, tgt)
+ for mode in ("train", "eval"):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
+ torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
+ y1 = orig(src, tgt)
+ torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
+ y2 = cust(src, tgt, dropout_rate=dropout_rate)
+
+ assert_close(y1, y2, atol=0, rtol=0)
+
+ if mode == 'eval': # one extra test: omit dropout at eval
torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
- y2 = cust(src, tgt, dropout_rate=dropout_rate)
-
+ y2 = cust(src, tgt)
assert_close(y1, y2, atol=0, rtol=0)
+ @parameterized.named_parameters(
+ dict(testcase_name="default", compile=False),
+ dict(testcase_name="default_compile", compile=True),
+ )
+ def test_default(self, compile):
+
+ orig = OriginalModel().to(DEVICE)
+ cust = CustomModel().to(DEVICE)
+
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ if compile:
+ orig = torch.compile(orig)
+ cust = torch.compile(cust)
+
+ src = _rand_tokens(B, SRC_LEN)
+ tgt = _rand_tokens(B, TGT_LEN)
+
+ for mode in ("train", "eval"):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
+ torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
+ y1 = orig(src, tgt)
+
+ torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
+ y2 = cust(src, tgt)
+
+ assert_close(y1, y2, atol=0, rtol=0)
+
+
if __name__ == "__main__":
absltest.main()
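All of the PyTorch equivalence tests above reduce to the same seed-matched comparison between the stock model and the refactored one. A minimal sketch of that pattern outside any particular workload (the helper name `check_equivalence` and the single-tensor forward signature are illustrative assumptions, not code from the suite):

    import torch
    from torch.testing import assert_close

    SEED = 0

    def check_equivalence(orig, cust, x, dropout_rate):
      """Seed-matched forward comparison between a stock and a refactored model."""
      orig.load_state_dict(cust.state_dict())  # sync weights first
      for mode in ('train', 'eval'):
        getattr(orig, mode)()
        getattr(cust, mode)()
        torch.manual_seed(SEED); y1 = orig(x)
        torch.manual_seed(SEED); y2 = cust(x, dropout_rate=dropout_rate)
        assert_close(y1, y2, atol=0, rtol=0)
        if mode == 'eval':  # dropout is inactive, so omitting the rate must match too
          torch.manual_seed(SEED); y2 = cust(x)
          assert_close(y1, y2, atol=0, rtol=0)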
From 5e192dd6397194d1435631752a1c29289b9a9888 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 12 Jun 2025 21:02:11 +0000
Subject: [PATCH 052/123] remove dropout_rate from init_model_fn for all jax
workloads
---
.../workloads/cifar/cifar_jax/workload.py | 6 +--
.../criteo1tb/criteo1tb_jax/workload.py | 34 ++++++-----------
.../workloads/fastmri/fastmri_jax/workload.py | 27 +++++--------
.../imagenet_resnet/imagenet_jax/workload.py | 7 +---
.../imagenet_vit/imagenet_jax/workload.py | 28 +++++---------
.../librispeech_jax/workload.py | 28 +++++---------
.../librispeech_jax/workload.py | 38 ++++++-------------
.../workloads/mnist/mnist_jax/workload.py | 7 +---
algoperf/workloads/ogbg/ogbg_jax/workload.py | 29 +++++---------
algoperf/workloads/wmt/wmt_jax/workload.py | 26 ++++---------
10 files changed, 71 insertions(+), 159 deletions(-)
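Every JAX workload in this commit follows the same mechanical change: `init_model_fn` loses its dropout arguments and seeds only the `'params'` collection. A minimal sketch of the resulting shape, using a stand-in Flax module (`TinyModel` and the fake input shape are illustrative, not a workload model):

    import functools

    import flax.linen as nn
    import jax
    import jax.numpy as jnp

    class TinyModel(nn.Module):

      @nn.compact
      def __call__(self, x, train=False):
        return nn.Dense(4)(x)

    def init_model_fn(rng):
      model = TinyModel()
      init_fn = functools.partial(model.init, train=False)
      params_rng, _ = jax.random.split(rng)
      # Only the 'params' stream is needed at init time; no 'dropout' rng anymore.
      variables = jax.jit(init_fn)({'params': params_rng},
                                   jnp.ones((2, 8), jnp.float32))
      return variables['params']

    params = init_model_fn(jax.random.PRNGKey(0))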
diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py
index ad43bc62f..3f2397f8c 100644
--- a/algoperf/workloads/cifar/cifar_jax/workload.py
+++ b/algoperf/workloads/cifar/cifar_jax/workload.py
@@ -81,12 +81,8 @@ def sync_batch_stats(
def init_model_fn(
self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ rng: spec.RandomState) -> spec.ModelInitState:
"""Dropout is unused."""
- del dropout_rate
- del aux_dropout_rate
model_cls = getattr(models, 'ResNet18')
model = model_cls(num_classes=self._num_classes, dtype=jnp.float32)
self._model = model
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
index 101e02c15..dcb7b9a57 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -72,7 +72,6 @@ def loss_fn(
def init_model_fn(
self,
rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
tabulate: Optional[bool] = False,
) -> spec.ModelInitState:
"""Only dropout is used."""
@@ -81,27 +80,16 @@ def init_model_fn(
else:
model_class = models.DlrmSmall
- if dropout_rate is None:
- self._model = model_class(
- vocab_size=self.vocab_size,
- num_dense_features=self.num_dense_features,
- mlp_bottom_dims=self.mlp_bottom_dims,
- mlp_top_dims=self.mlp_top_dims,
- embed_dim=self.embed_dim,
- use_layer_norm=self.use_layer_norm,
- embedding_init_multiplier=self.embedding_init_multiplier)
- else:
- self._model = model_class(
- vocab_size=self.vocab_size,
- num_dense_features=self.num_dense_features,
- mlp_bottom_dims=self.mlp_bottom_dims,
- mlp_top_dims=self.mlp_top_dims,
- embed_dim=self.embed_dim,
- dropout_rate=dropout_rate,
- use_layer_norm=self.use_layer_norm,
- embedding_init_multiplier=self.embedding_init_multiplier)
-
- params_rng, dropout_rng = jax.random.split(rng)
+ self._model = model_class(
+ vocab_size=self.vocab_size,
+ num_dense_features=self.num_dense_features,
+ mlp_bottom_dims=self.mlp_bottom_dims,
+ mlp_top_dims=self.mlp_top_dims,
+ embed_dim=self.embed_dim,
+ use_layer_norm=self.use_layer_norm,
+ embedding_init_multiplier=self.embedding_init_multiplier)
+
+    params_rng, _ = jax.random.split(rng)
init_fake_batch_size = 2
num_categorical_features = 26
num_dense_features = 13
@@ -109,7 +97,7 @@ def init_model_fn(
input_shape = (init_fake_batch_size, input_size)
init_fn = functools.partial(self._model.init, train=False)
initial_variables = jax.jit(init_fn)(
- {'params': params_rng, 'dropout': dropout_rng},
+        {'params': params_rng},
jnp.ones(input_shape, jnp.float32))
initial_params = initial_variables['params']
self._param_shapes = param_utils.jax_param_shapes(initial_params)
diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py
index 3d891cf8f..bf0acfc8d 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/workload.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py
@@ -21,28 +21,19 @@ class FastMRIWorkload(BaseFastMRIWorkload):
def init_model_fn(
self,
rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
) -> spec.ModelInitState:
"""aux_dropout_rate is unused."""
fake_batch = jnp.zeros((13, 320, 320))
- if dropout_rate is None:
- self._model = UNet(
- num_pool_layers=self.num_pool_layers,
- num_channels=self.num_channels,
- use_tanh=self.use_tanh,
- use_layer_norm=self.use_layer_norm,
+ self._model = UNet(
+ num_pool_layers=self.num_pool_layers,
+ num_channels=self.num_channels,
+ use_tanh=self.use_tanh,
+ use_layer_norm=self.use_layer_norm,
)
- else:
- self._model = UNet(
- num_pool_layers=self.num_pool_layers,
- num_channels=self.num_channels,
- use_tanh=self.use_tanh,
- use_layer_norm=self.use_layer_norm,
- dropout_rate=dropout_rate)
-
- params_rng, dropout_rng = jax.random.split(rng)
- variables = jax.jit(
- self._model.init)({'params': params_rng, 'dropout': dropout_rng},
+
+ params_rng, _ = jax.random.split(rng)
+ init_fn = functools.partial(self._model.init, train=False)
+ variables = jax.jit(init_fn)({'params': params_rng},
fake_batch)
params = variables['params']
self._param_shapes = param_utils.jax_param_shapes(params)
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
index 4ec3937b8..2a255fee4 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
@@ -84,12 +84,7 @@ def sync_batch_stats(
def init_model_fn(
self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """Dropout is unused."""
- del dropout_rate
- del aux_dropout_rate
+      rng: spec.RandomState) -> spec.ModelInitState:
model_cls = getattr(models, 'ResNet50')
if self.use_silu and self.use_gelu:
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
index 89355ac6e..b8a870de5 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -23,32 +23,22 @@ class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload):
def initialized(self, key: spec.RandomState,
model: nn.Module) -> spec.ModelInitState:
input_shape = (1, 224, 224, 3)
- params_rng, dropout_rng = jax.random.split(key)
+ params_rng, _ = jax.random.split(key)
variables = jax.jit(
- model.init)({'params': params_rng, 'dropout': dropout_rng},
+ model.init)({'params': params_rng},
jnp.ones(input_shape))
model_state, params = pop(variables, "params")
return params, model_state
def init_model_fn(
self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- if dropout_rate is None:
- self._model = models.ViT(
- num_classes=self._num_classes,
- use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm,
- use_map=self.use_map,
- **decode_variant('S/16'))
- else:
- self._model = models.ViT(
- dropout_rate=dropout_rate,
- num_classes=self._num_classes,
- use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm,
- use_map=self.use_map,
- **decode_variant('S/16'))
+ rng: spec.RandomState) -> spec.ModelInitState:
+ self._model = models.ViT(
+ num_classes=self._num_classes,
+ use_glu=self.use_glu,
+ use_post_layer_norm=self.use_post_layer_norm,
+ use_map=self.use_map,
+ **decode_variant('S/16'))
params, model_state = self.initialized(rng, self._model)
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
index 2e082cf07..042dba7f4 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
@@ -60,7 +60,6 @@ def attention_temperature(self) -> float:
def init_model_fn(
self,
rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
) -> spec.ModelInitState:
"""Conformer model init function.
@@ -71,21 +70,14 @@ def init_model_fn(
activation_function_name = 'gelu'
else:
activation_function_name = 'swish'
- if dropout_rate is None:
- model_config = models.ConformerConfig(
- attention_residual_dropout_rate=dropout_rate,
- feed_forward_residual_dropout_rate=dropout_rate,
- input_dropout_rate=dropout_rate,
- use_specaug=self.use_specaug,
- attention_temperature=self.attention_temperature,
- use_post_layer_norm=self.use_post_layer_norm,
- activation_function_name=activation_function_name)
- else:
- model_config = models.ConformerConfig(
- use_specaug=self.use_specaug,
- attention_temperature=self.attention_temperature,
- use_post_layer_norm=self.use_post_layer_norm,
- activation_function_name=activation_function_name)
+    model_config = models.ConformerConfig(
+        use_specaug=self.use_specaug,
+        attention_temperature=self.attention_temperature,
+        use_post_layer_norm=self.use_post_layer_norm,
+        activation_function_name=activation_function_name)
self._model = models.Conformer(model_config)
input_shape = [(320000,), (320000,)]
@@ -93,8 +85,8 @@ def init_model_fn(
model_init_fn = jax.jit(functools.partial(self._model.init, train=False))
- params_rng, dropout_rng = jax.random.split(rng, 2)
- variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng},
+ params_rng, _ = jax.random.split(rng, 2)
+ variables = model_init_fn({'params': params_rng},
*fake_input_batch)
model_state, params = pop(variables, "params")
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
index 825b470db..2213f189e 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
@@ -17,40 +17,26 @@ class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload):
def init_model_fn(
self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ rng: spec.RandomState) -> spec.ModelInitState:
"""Deepspeech model init function.
"""
- if dropout_rate is None:
- model_config = models.DeepspeechConfig(
- use_specaug=self.use_specaug,
- use_tanh=self.use_tanh,
- enable_residual_connections=self.enable_residual_connections,
- enable_decoder_layer_norm=self.enable_decoder_layer_norm,
- layernorm_everywhere=self.layernorm_everywhere,
- freq_mask_count=self.freq_mask_count,
- time_mask_count=self.time_mask_count,
- )
- else:
- model_config = models.DeepspeechConfig(
- feed_forward_dropout_rate=dropout_rate,
- use_specaug=self.use_specaug,
- input_dropout_rate=dropout_rate,
- use_tanh=self.use_tanh,
- enable_residual_connections=self.enable_residual_connections,
- enable_decoder_layer_norm=self.enable_decoder_layer_norm,
- layernorm_everywhere=self.layernorm_everywhere,
- freq_mask_count=self.freq_mask_count,
- time_mask_count=self.time_mask_count,
- )
+ model_config = models.DeepspeechConfig(
+ use_specaug=self.use_specaug,
+ use_tanh=self.use_tanh,
+ enable_residual_connections=self.enable_residual_connections,
+ enable_decoder_layer_norm=self.enable_decoder_layer_norm,
+ layernorm_everywhere=self.layernorm_everywhere,
+ freq_mask_count=self.freq_mask_count,
+ time_mask_count=self.time_mask_count,
+ )
self._model = models.Deepspeech(model_config)
input_shape = [(320000,), (320000,)]
fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape]
model_init_fn = jax.jit(functools.partial(self._model.init, train=False))
- params_rng, dropout_rng = jax.random.split(rng, 2)
- variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng},
+ params_rng, _ = jax.random.split(rng, 2)
+    variables = model_init_fn({'params': params_rng},
*fake_input_batch)
model_state = variables[
diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py
index 5a4382da1..5f3fdcf78 100644
--- a/algoperf/workloads/mnist/mnist_jax/workload.py
+++ b/algoperf/workloads/mnist/mnist_jax/workload.py
@@ -34,12 +34,7 @@ class MnistWorkload(BaseMnistWorkload):
def init_model_fn(
self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """Dropout is unused."""
- del dropout_rate
- del aux_dropout_rate
+ rng: spec.RandomState) -> spec.ModelInitState:
init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
self._model = _Model()
initial_params = self._model.init({'params': rng}, init_val,
diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py
index 3becd5599..aaa5b4064 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/workload.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py
@@ -19,25 +19,14 @@ class OgbgWorkload(BaseOgbgWorkload):
def init_model_fn(
self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """aux_dropout_rate is unused."""
- rng, params_rng, dropout_rng = jax.random.split(rng, 3)
- if dropout_rate is None:
- self._model = models.GNN(
- self._num_outputs,
- activation_fn_name=self.activation_fn_name,
- hidden_dims=self.hidden_dims,
- latent_dim=self.latent_dim,
- num_message_passing_steps=self.num_message_passing_steps)
- else:
- self._model = models.GNN(
- self._num_outputs,
- dropout_rate=dropout_rate,
- activation_fn_name=self.activation_fn_name,
- hidden_dims=self.hidden_dims,
- latent_dim=self.latent_dim,
- num_message_passing_steps=self.num_message_passing_steps)
+ rng: spec.RandomState) -> spec.ModelInitState:
+ rng, params_rng = jax.random.split(rng, 2)
+ self._model = models.GNN(
+ self._num_outputs,
+ activation_fn_name=self.activation_fn_name,
+ hidden_dims=self.hidden_dims,
+ latent_dim=self.latent_dim,
+ num_message_passing_steps=self.num_message_passing_steps)
init_fn = jax.jit(functools.partial(self._model.init, train=False))
fake_batch = jraph.GraphsTuple(
n_node=jnp.asarray([1]),
@@ -47,7 +36,7 @@ def init_model_fn(
globals=jnp.zeros((1, self._num_outputs)),
senders=jnp.asarray([0]),
receivers=jnp.asarray([0]))
- params = init_fn({'params': params_rng, 'dropout': dropout_rng}, fake_batch)
+ params = init_fn({'params': params_rng}, fake_batch)
params = params['params']
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py
index 193732640..9e109dc86 100644
--- a/algoperf/workloads/wmt/wmt_jax/workload.py
+++ b/algoperf/workloads/wmt/wmt_jax/workload.py
@@ -208,8 +208,7 @@ def translate_and_calculate_bleu(self,
def init_model_fn(
self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = 0.0) -> spec.ModelInitState:
+ rng: spec.RandomState) -> spec.ModelInitState:
init_fake_batch_size = 2
input_shape = (init_fake_batch_size, 256)
target_shape = (init_fake_batch_size, 256)
@@ -221,26 +220,17 @@ def init_model_fn(
else:
raise ValueError(f'Unknown activation function {self.activation}.')
- if dropout_rate is None:
- model_config = models.TransformerConfig(
- pre_ln=self.pre_ln,
- attention_temp=self.attention_temp,
- activation=activation,
- glu=self.glu)
- else:
- model_config = models.TransformerConfig(
- dropout_rate=dropout_rate,
- attention_dropout_rate=dropout_rate,
- pre_ln=self.pre_ln,
- attention_temp=self.attention_temp,
- activation=activation,
- glu=self.glu)
+ model_config = models.TransformerConfig(
+ pre_ln=self.pre_ln,
+ attention_temp=self.attention_temp,
+ activation=activation,
+ glu=self.glu)
self._train_model = models.Transformer(model_config)
eval_config = replace(model_config, deterministic=True)
self._eval_model = models.Transformer(eval_config)
- params_rng, dropout_rng = jax.random.split(rng)
+ params_rng, _ = jax.random.split(rng)
initial_variables = jax.jit(
- self._eval_model.init)({'params': params_rng, 'dropout': dropout_rng},
+ self._eval_model.init)({'params': params_rng},
jnp.ones(input_shape, jnp.float32),
jnp.ones(target_shape, jnp.float32))
From 23828cdb0d54207a6714263b4d9f44531011f375 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 12 Jun 2025 21:03:11 +0000
Subject: [PATCH 053/123] remove dropout from model initialization call in
submission_runner.py
---
submission_runner.py | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/submission_runner.py b/submission_runner.py
index bb4a8c6cc..d076a1043 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -228,11 +228,8 @@ def train_once(
global_batch_size=global_batch_size)
logging.info('Initializing model.')
with profiler.profile('Initializing model'):
- dropout_rate = None
- if hasattr(hyperparameters, 'dropout_rate'):
- dropout_rate = hyperparameters.dropout_rate
model_params, model_state = workload.init_model_fn(
- model_init_rng, dropout_rate)
+ model_init_rng)
if FLAGS.framework == 'pytorch' and FLAGS.torch_compile:
compile_error_workloads = [
'librispeech_conformer',
From 86b86245a3754d59b8e707eecacd3be0477419d7 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 12 Jun 2025 21:28:35 +0000
Subject: [PATCH 054/123] remove dropout check for None and use default instead
if not passed
---
.../criteo1tb/criteo1tb_jax/models.py | 11 +++----
.../workloads/fastmri/fastmri_jax/models.py | 12 +++-----
.../imagenet_vit/imagenet_jax/models.py | 26 +++++------------
.../librispeech_jax/models.py | 27 +++++------------
.../librispeech_jax/models.py | 21 ++++----------
algoperf/workloads/ogbg/ogbg_jax/models.py | 7 ++---
algoperf/workloads/wmt/wmt_jax/models.py | 29 +++++++------------
7 files changed, 41 insertions(+), 92 deletions(-)
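The model-side counterpart of the previous two commits: the `dropout_rate: Optional[float] = None` attributes and their `if dropout_rate is None:` fallbacks are replaced by a module-level constant that serves as the call-time default. A minimal sketch under those conventions (the `Block` module and feature sizes are illustrative; the `Dropout` layer taking a constructor rate is the one from `algoperf.jax_utils`):

    import flax.linen as nn
    import jax
    import jax.numpy as jnp

    from algoperf.jax_utils import Dropout

    DROPOUT_RATE = 0.1  # module-level default instead of per-module Optional fields

    class Block(nn.Module):

      @nn.compact
      def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE):
        # No `if dropout_rate is None:` fallback; the default is the constant above.
        x = nn.Dense(16)(x)
        return Dropout(dropout_rate, deterministic=not train)(x)

    x = jnp.ones((2, 8))
    variables = Block().init({'params': jax.random.PRNGKey(0)}, x, train=False)
    y = Block().apply(variables, x, train=True, dropout_rate=0.2,
                      rngs={'dropout': jax.random.PRNGKey(1)})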
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
index b7af15208..57cb7f2d9 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
@@ -7,6 +7,7 @@
from algoperf.jax_utils import Dropout
+DROPOUT_RATE = 0.0
class DLRMResNet(nn.Module):
"""Define a DLRMResNet model.
@@ -24,14 +25,12 @@ class DLRMResNet(nn.Module):
mlp_bottom_dims: Sequence[int] = (256, 256, 256)
mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1)
embed_dim: int = 128
- dropout_rate: float = 0.0
+ dropout_rate: float = DROPOUT_RATE
use_layer_norm: bool = False # Unused.
embedding_init_multiplier: float = None # Unused
@nn.compact
- def __call__(self, x, train, dropout_rate=None):
- if dropout_rate is None:
- dropout_rate = self.dropout_rate
+ def __call__(self, x, train, dropout_rate=DROPOUT_RATE):
bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
@@ -155,9 +154,7 @@ class DlrmSmall(nn.Module):
embedding_init_multiplier: float = None
@nn.compact
- def __call__(self, x, train, dropout_rate=None):
- if dropout_rate is None:
- dropout_rate = self.dropout_rate
+ def __call__(self, x, train, dropout_rate=DROPOUT_RATE):
bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py
index b04510297..5850defa7 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/models.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/models.py
@@ -21,6 +21,7 @@
from algoperf.jax_utils import Dropout
+DROPOUT_RATE = 0.0
def _instance_norm2d(x, axes, epsilon=1e-5):
# promote x to at least float32, this avoids half precision computation
@@ -58,15 +59,12 @@ class UNet(nn.Module):
num_channels: int = 32
num_pool_layers: int = 4
out_channels = 1
- dropout_rate: float = 0.0
+ dropout_rate: float = DROPOUT_RATE
use_tanh: bool = False
use_layer_norm: bool = False
@nn.compact
- def __call__(self, x, train=True, dropout_rate=None):
- if dropout_rate is None:
- dropout_rate = self.dropout_rate
-
+ def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE):
# pylint: disable=invalid-name
_ConvBlock = functools.partial(
ConvBlock,
@@ -144,7 +142,7 @@ class ConvBlock(nn.Module):
dropout_rate: float = 0.0
@nn.compact
- def __call__(self, x, train=True, dropout_rate=None):
+ def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE):
"""Forward function.
Note: Pytorch is NCHW and jax/flax is NHWC.
Args:
@@ -153,8 +151,6 @@ def __call__(self, x, train=True, dropout_rate=None):
Returns:
jnp.array: Output tensor of shape `(N, H, W, out_channels)`.
"""
- if dropout_rate is None:
- dropout_rate = self.dropout_rate
x = nn.Conv(
features=self.out_channels,
kernel_size=(3, 3),
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
index 8ffc0b610..7c5d7bd26 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
@@ -13,6 +13,7 @@
from algoperf import spec
from algoperf.jax_utils import Dropout
+DROPOUT_RATE = 0.0
def posemb_sincos_2d(h: int,
w: int,
@@ -36,17 +37,14 @@ class MlpBlock(nn.Module):
"""Transformer MLP / feed-forward block."""
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
use_glu: bool = False
- dropout_rate: float = 0.0
+ dropout_rate: float = DROPOUT_RATE
@nn.compact
def __call__(self,
x: spec.Tensor,
train: bool = True,
- dropout_rate=None) -> spec.Tensor:
+ dropout_rate=DROPOUT_RATE) -> spec.Tensor:
"""Applies Transformer MlpBlock module."""
- if dropout_rate is None:
- dropout_rate = self.dropout_rate
-
inits = {
'kernel_init': nn.initializers.xavier_uniform(),
'bias_init': nn.initializers.normal(stddev=1e-6),
@@ -78,8 +76,6 @@ def __call__(self,
x: spec.Tensor,
train: bool = True,
-               dropout_rate=None) -> spec.Tensor:
+               dropout_rate=DROPOUT_RATE) -> spec.Tensor:
- if dropout_rate is None:
- dropout_rate = self.dropout_rate
if not self.use_post_layer_norm:
y = nn.LayerNorm(name='LayerNorm_0')(x)
@@ -136,11 +132,7 @@ class Encoder(nn.Module):
@nn.compact
def __call__(self,
x: spec.Tensor,
- train: bool = True,
- dropout_rate=None) -> spec.Tensor:
- if dropout_rate is None:
- dropout_rate = self.dropout_rate
-
+ train: bool = True) -> spec.Tensor:
# Input Encoder
for lyr in range(self.depth):
block = Encoder1DBlock(
@@ -165,9 +157,7 @@ class MAPHead(nn.Module):
dropout_rate: float = 0.0
@nn.compact
- def __call__(self, x, dropout_rate=None):
- if dropout_rate is None:
- dropout_rate = self.dropout_rate
+ def __call__(self, x, dropout_rate=DROPOUT_RATE):
n, _, d = x.shape
probe = self.param('probe',
nn.initializers.xavier_uniform(), (1, 1, d),
@@ -194,7 +184,7 @@ class ViT(nn.Module):
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
num_heads: int = 12
rep_size: Union[int, bool] = True
- dropout_rate: Optional[float] = 0.0
+  dropout_rate: float = DROPOUT_RATE
reinit: Optional[Sequence[str]] = None
head_zeroinit: bool = True
use_glu: bool = False
@@ -212,9 +202,7 @@ def __call__(self,
x: spec.Tensor,
*,
train: bool = False,
- dropout_rate=None) -> spec.Tensor:
- if dropout_rate is None:
- dropout_rate = self.dropout_rate
+ dropout_rate=DROPOUT_RATE) -> spec.Tensor:
# Patch extraction
x = nn.Conv(
self.width,
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
index 2d0da15e5..f7beed914 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
@@ -28,6 +28,7 @@
spectrum_augmenter
from algoperf.jax_utils import Dropout
+DROPOUT_RATE = 0.1
@struct.dataclass
class ConformerConfig:
@@ -37,11 +38,7 @@ class ConformerConfig:
encoder_dim: int = 512
num_attention_heads: int = 8
num_encoder_layers: int = 4
- dropout_rate: float = 0.1
- attention_residual_dropout_rate: Optional[float] = 0.0
- conv_residual_dropout_rate: Optional[float] = 0.0
- feed_forward_dropout_rate: float = 0.0
- feed_forward_residual_dropout_rate: Optional[float] = 0.0
+ dropout_rate: float = DROPOUT_RATE
convolution_kernel_size: int = 5
feed_forward_expansion_factor: int = 4
freq_mask_count: int = 2
@@ -96,12 +93,8 @@ class Subsample(nn.Module):
input_dropout_rate: dropout rate for inputs.
"""
encoder_dim: int = 0
- dropout_rate: float = 0.0
-
@nn.compact
- def __call__(self, inputs, input_paddings, train, dropout_rate=None):
- if dropout_rate is None:
- dropout_rate = self.dropout_rate
+ def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE):
output_paddings = input_paddings
outputs = jnp.expand_dims(inputs, axis=-1)
@@ -196,7 +189,7 @@ class FeedForwardModule(nn.Module):
config: ConformerConfig
@nn.compact
- def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=None):
+ def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_RATE):
config = self.config
if dropout_rate is None:
dropout_rate = config.dropout_rate
@@ -388,10 +381,8 @@ class MultiHeadedSelfAttention(nn.Module):
config: ConformerConfig = None
@nn.compact
- def __call__(self, inputs, paddings, train, dropout_rate=None):
+ def __call__(self, inputs, paddings, train, dropout_rate=DROPOUT_RATE):
config = self.config
- if dropout_rate is None:
- dropout_rate = config.dropout_rate
mask_paddings = 1 - paddings
attention_mask = nn.make_attention_mask(
@@ -527,10 +518,8 @@ def __call__(self,
train,
update_batch_norm,
use_running_average_bn,
- dropout_rate=None):
+ dropout_rate=DROPOUT_RATE):
config = self.config
- if dropout_rate is None:
- dropout_rate = config.dropout_rate
inputs = LayerNorm(dim=config.encoder_dim)(inputs)
input_gated1 = nn.Dense(
@@ -603,7 +592,7 @@ def __call__(self,
train,
update_batch_norm,
use_running_average,
- dropout_rate=None):
+ dropout_rate=DROPOUT_RATE):
config = self.config
padding_mask = jnp.expand_dims(1 - input_paddings, -1)
@@ -658,7 +647,7 @@ def __call__(self,
train,
update_batch_norm: Optional[bool] = None,
use_running_average_bn: Optional[bool] = None,
- dropout_rate: Optional[float] = None):
+      dropout_rate: float = DROPOUT_RATE):
config = self.config
outputs = inputs
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
index 455366e5e..84ba58ee2 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
@@ -1,4 +1,4 @@
-r"""Deepspeech.
+"""Deepspeech.
This model uses a deepspeech2 network to convert speech to text.
paper : https://arxiv.org/abs/1512.02595
@@ -31,6 +31,8 @@
CarryHistory = Any
Output = Any
+DROPOUT_RATE = 0.1
+
@struct.dataclass
class DeepspeechConfig:
@@ -52,10 +54,6 @@ class DeepspeechConfig:
use_dynamic_time_mask_max_frames: bool = True
batch_norm_momentum: float = 0.999
batch_norm_epsilon: float = 0.001
- # If None, defaults to 0.1.
- input_dropout_rate: Optional[float] = 0.1
- # If None, defaults to 0.1.
- feed_forward_dropout_rate: Optional[float] = 0.1
enable_residual_connections: bool = True
enable_decoder_layer_norm: bool = True
bidirectional: bool = True
@@ -73,11 +71,8 @@ class Subsample(nn.Module):
config: DeepspeechConfig
@nn.compact
- def __call__(self, inputs, output_paddings, train, dropout_rate=None):
+ def __call__(self, inputs, output_paddings, train, dropout_rate=DROPOUT_RATE):
config = self.config
- if dropout_rate is None:
- dropout_rate = config.dropout_rate
-
outputs = jnp.expand_dims(inputs, axis=-1)
outputs, output_paddings = Conv2dSubsampling(
@@ -196,9 +191,7 @@ def __call__(self,
inputs,
input_paddings=None,
train=False,
- dropout_rate=None):
- if dropout_rate is None:
- dropout_rate = self.config.feed_forward_dropout_rate
+ dropout_rate=DROPOUT_RATE):
padding_mask = jnp.expand_dims(1 - input_paddings, -1)
config = self.config
@@ -479,10 +472,8 @@ def setup(self):
)
@nn.compact
- def __call__(self, inputs, input_paddings, train, dropout_rate=None):
+ def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE):
config = self.config
- if dropout_rate is None:
- dropout_rate = config.dropout_rate
outputs = inputs
output_paddings = input_paddings
diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py
index f6cb1c490..59d989284 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/models.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/models.py
@@ -8,6 +8,7 @@
from algoperf.jax_utils import Dropout
+DROPOUT_RATE = 0.1
def _make_embed(latent_dim, name):
@@ -41,15 +42,11 @@ class GNN(nn.Module):
num_outputs: int
latent_dim: int = 256
hidden_dims: Tuple[int] = (256,)
- # If None, defaults to 0.1.
- dropout_rate: Optional[float] = 0.1
num_message_passing_steps: int = 5
activation_fn_name: str = 'relu'
@nn.compact
- def __call__(self, graph, train, dropout_rate=None):
- if dropout_rate is not None:
- dropout_rate = self.dropout_rate
+ def __call__(self, graph, train, dropout_rate=DROPOUT_RATE):
    dropout = Dropout(dropout_rate, deterministic=not train)
graph = graph._replace(
diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py
index 3947a1b81..e262214ac 100644
--- a/algoperf/workloads/wmt/wmt_jax/models.py
+++ b/algoperf/workloads/wmt/wmt_jax/models.py
@@ -14,6 +14,8 @@
from algoperf.jax_utils import Dropout
+DROPOUT_RATE = 0.1
+
@struct.dataclass
class TransformerConfig:
"""Global hyperparameters used to minimize obnoxious kwarg plumbing."""
@@ -28,7 +30,6 @@ class TransformerConfig:
max_len: int = 256
activation: Callable = nn.relu
glu: bool = False
- dropout_rate: Optional[float] = 0.1
attention_temp: float = 1.0
deterministic: bool = False
decode: bool = False
@@ -148,11 +149,9 @@ class MlpBlock(nn.Module):
out_dim: Optional[int] = None
@nn.compact
- def __call__(self, inputs, dropout_rate=None):
+ def __call__(self, inputs, dropout_rate=DROPOUT_RATE):
"""Applies Transformer MlpBlock module."""
cfg = self.config
- if dropout_rate is None:
- dropout_rate = cfg.dropout_rate
actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
x = nn.Dense(
@@ -195,7 +194,7 @@ class Encoder1DBlock(nn.Module):
config: TransformerConfig
@nn.compact
- def __call__(self, inputs, encoder_mask=None, dropout_rate=None):
+ def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE):
"""Applies Encoder1DBlock module.
Args:
@@ -206,8 +205,6 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=None):
output after transformer encoder block.
"""
cfg = self.config
- if dropout_rate is None:
- dropout_rate = cfg.dropout_rate
pre_ln = cfg.pre_ln
@@ -254,7 +251,7 @@ def __call__(
encoded,
decoder_mask=None,
encoder_decoder_mask=None,
- dropout_rate=None,
+ dropout_rate=DROPOUT_RATE,
):
"""Applies EncoderDecoder1DBlock module.
@@ -268,8 +265,6 @@ def __call__(
output after transformer encoder-decoder block.
"""
cfg = self.config
- if dropout_rate is None:
- dropout_rate = cfg.dropout_rate
pre_ln = cfg.pre_ln
@@ -337,7 +332,7 @@ def __call__(self,
inputs,
inputs_positions=None,
encoder_mask=None,
- dropout_rate=None):
+ dropout_rate=DROPOUT_RATE):
"""Applies Transformer model on the inputs.
Args:
@@ -349,8 +344,6 @@ def __call__(self,
output of a transformer encoder.
"""
cfg = self.config
- if dropout_rate is None:
- dropout_rate = cfg.dropout_rate
assert inputs.ndim == 2 # (batch, len)
@@ -403,7 +396,7 @@ def __call__(
targets_positions=None,
decoder_mask=None,
encoder_decoder_mask=None,
- dropout_rate=None,
+ dropout_rate=DROPOUT_RATE,
):
"""Applies Transformer model on the inputs.
@@ -418,8 +411,6 @@ def __call__(
output of a transformer decoder.
"""
cfg = self.config
- if dropout_rate is None:
- dropout_rate = cfg.dropout_rate
assert encoded.ndim == 3 # (batch, len, depth)
assert targets.ndim == 2 # (batch, len)
@@ -495,7 +486,7 @@ def encode(self,
inputs,
inputs_positions=None,
inputs_segmentation=None,
- dropout_rate=None):
+ dropout_rate=DROPOUT_RATE):
"""Applies Transformer encoder-branch on the inputs.
Args:
@@ -533,7 +524,7 @@ def decode(
targets_positions=None,
inputs_segmentation=None,
targets_segmentation=None,
- dropout_rate=None):
+ dropout_rate=DROPOUT_RATE):
"""Applies Transformer decoder-branch on encoded-input and target.
Args:
@@ -593,7 +584,7 @@ def __call__(self,
targets_positions=None,
inputs_segmentation=None,
targets_segmentation=None,
- dropout_rate=None):
+ dropout_rate=DROPOUT_RATE):
"""Applies Transformer model on the inputs.
Args:
From 0128c9fe4caa2590c60af6c7276ab6789d06ae6d Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Fri, 13 Jun 2025 13:50:00 +0200
Subject: [PATCH 055/123] pipe dropout to model_fn, set default in workload
---
algoperf/spec.py | 6 +++--
.../criteo1tb_pytorch/models_dropout.py | 6 ++---
.../criteo1tb/criteo1tb_pytorch/workload.py | 6 +++--
.../fastmri/fastmri_pytorch/models_dropout.py | 4 +--
.../fastmri/fastmri_pytorch/workload.py | 9 ++++---
.../imagenet_pytorch/workload.py | 9 +++----
.../imagenet_pytorch/models_dropout.py | 4 +--
.../imagenet_vit/imagenet_pytorch/workload.py | 7 +++--
.../librispeech_pytorch/models_dropout.py | 27 +++++--------------
.../librispeech_pytorch/workload.py | 7 +++--
.../librispeech_pytorch/models_dropout.py | 4 +--
.../ogbg/ogbg_pytorch/models_dropout.py | 4 +--
.../workloads/ogbg/ogbg_pytorch/workload.py | 8 ++++--
.../wmt/wmt_pytorch/models_dropout.py | 4 +--
.../workloads/wmt/wmt_pytorch/workload.py | 8 ++++--
15 files changed, 60 insertions(+), 53 deletions(-)
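On the PyTorch side the change is equally mechanical: `model_fn` gains a `dropout_rate` argument with the workload's default and forwards it to the model's `forward`. A minimal sketch of that flow with a toy model (the class names, the 'train'/'eval' string switch, and the single-tensor batch are illustrative stand-ins for the spec types used in the diffs):

    import contextlib

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    DROPOUT_RATE = 0.1  # workload-level default

    class TinyModel(nn.Module):

      def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 4)

      def forward(self, x, dropout_rate=DROPOUT_RATE):
        # The rate arrives at call time instead of being baked in at construction.
        return F.dropout(self.lin(x), p=dropout_rate, training=self.training)

    def model_fn(model, batch, mode, dropout_rate=DROPOUT_RATE):
      model.train(mode == 'train')
      ctx = torch.no_grad() if mode == 'eval' else contextlib.nullcontext()
      with ctx:
        logits = model(batch, dropout_rate=dropout_rate)
      return logits, None

    logits, _ = model_fn(TinyModel(), torch.randn(2, 8), 'train', dropout_rate=0.2)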
diff --git a/algoperf/spec.py b/algoperf/spec.py
index cf4f1a14e..9670dcb76 100644
--- a/algoperf/spec.py
+++ b/algoperf/spec.py
@@ -247,7 +247,8 @@ def init_model_fn(self,
# ModelAuxiliaryState,
# ForwardPassMode,
# RandomState,
- # bool],
+ # bool,
+ # float],
# Tensor]
@abc.abstractmethod
def model_fn(self,
@@ -256,7 +257,8 @@ def model_fn(self,
model_state: ModelAuxiliaryState,
mode: ForwardPassMode,
rng: RandomState,
- update_batch_norm: bool) -> Tuple[Tensor, ModelAuxiliaryState]:
+ update_batch_norm: bool,
+ dropout_rate: float) -> Tuple[Tensor, ModelAuxiliaryState]:
"""Return logits_batch"""
# Possible side effect of updating BN.
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
index b5ee465e2..f0653a665 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
@@ -7,7 +7,7 @@
from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout
-DEFAULT_DROPOUT_RATE = 0.0
+DROPOUT_RATE = 0.0
class DenseBlock(nn.Module):
@@ -148,7 +148,7 @@ def __init__(self,
0.,
math.sqrt(1. / module.out_features))
- def forward(self, x, dropout_rate=DEFAULT_DROPOUT_RATE):
+ def forward(self, x, dropout_rate=DROPOUT_RATE):
batch_size = x.shape[0]
@@ -269,7 +269,7 @@ def __init__(self,
0.,
math.sqrt(1. / module.out_features))
- def forward(self, x, dropout_rate=DEFAULT_DROPOUT_RATE):
+ def forward(self, x, dropout_rate=DROPOUT_RATE):
batch_size = x.shape[0]
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
index b128f5bd5..74cb3e140 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
@@ -15,6 +15,7 @@
BaseCriteo1TbDlrmSmallWorkload
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
+DROPOUT_RATE = 0.0
class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload):
@@ -103,7 +104,8 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ update_batch_norm: bool,
+ dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
@@ -123,7 +125,7 @@ def model_fn(
}
with contexts[mode]():
- logits_batch = model(inputs)
+ logits_batch = model(inputs, dropout_rate=dropout_rate)
return logits_batch, None
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
index 0e59e1436..0b8ac5499 100644
--- a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
@@ -15,7 +15,7 @@
from algoperf import init_utils
from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout
-DEFAULT_DROPOUT_RATE = 0.0
+DROPOUT_RATE = 0.0
class UNet(nn.Module):
@@ -77,7 +77,7 @@ def __init__(self,
def forward(
self,
x: Tensor,
- dropout_rate: float = DEFAULT_DROPOUT_RATE) -> Tensor:
+ dropout_rate: float = DROPOUT_RATE) -> Tensor:
stack = []
output = x
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
index 6da0bb0af..6374c62d6 100644
--- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
@@ -19,6 +19,8 @@
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
+DROPOUT_RATE = 0.0
+
class FastMRIWorkload(BaseFastMRIWorkload):
@@ -134,7 +136,8 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ update_batch_norm: bool,
+ dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
@@ -154,8 +157,8 @@ def model_fn(
with contexts[mode]():
logit_batch = model(
- augmented_and_preprocessed_input_batch['inputs'].unsqueeze(
- 1)).squeeze(1)
+ augmented_and_preprocessed_input_batch['inputs'].unsqueeze(1),
+ dropout_rate=dropout_rate).squeeze(1)
return logit_batch, None
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
index 372cac7fa..f28eb1762 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
@@ -156,10 +156,7 @@ def _build_dataset(
def init_model_fn(
self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None) -> spec.ModelInitState:
- """Dropout is unused."""
- del dropout_rate
+ rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
if self.use_silu and self.use_gelu:
@@ -192,9 +189,11 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ update_batch_norm: bool,
+ dropout_rate: float = 0.0) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
+ del dropout_rate
model = params
diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py
index 570cee575..60e09edb5 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py
@@ -16,7 +16,7 @@
from algoperf import spec
from algoperf.workloads.wmt.wmt_pytorch.models_dropout import MultiheadAttention
-DEFAULT_DROPOUT_RATE = 0.0
+DROPOUT_RATE = 0.0
def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor:
@@ -348,7 +348,7 @@ def get_posemb(self, x: spec.Tensor) -> spec.Tensor:
def forward(
self,
x: spec.Tensor,
- dropout_rate: float = DEFAULT_DROPOUT_RATE) -> spec.Tensor:
+ dropout_rate: float = DROPOUT_RATE) -> spec.Tensor:
# Patch extraction.
x = self.conv_patch_extract(x)
diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
index 20bd3828b..8b011071a 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
@@ -16,6 +16,7 @@
from algoperf.workloads.imagenet_vit.workload import decode_variant
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
+DROPOUT_RATE = 0.0
# Make sure we inherit from the ViT base workload first.
@@ -51,7 +52,8 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ update_batch_norm: bool,
+ dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
@@ -70,7 +72,8 @@ def model_fn(
}
with contexts[mode]():
- logits_batch = model(augmented_and_preprocessed_input_batch['inputs'])
+ logits_batch = model(augmented_and_preprocessed_input_batch['inputs'],
+ dropout_rate=dropout_rate)
return logits_batch, None
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py
index f77c8a814..a6a60bf95 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py
@@ -17,10 +17,7 @@
from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \
SpecAug
-DEFAULT_ATTN_RESIDUAL_DROPOUT_RATE = 0.1
-DEFAULT_CONV_RESIDUAL_DROPOUT_RATE = 0.0
-DEFAULT_FFN_RESIDUAL_DROPOUT_RATE = 0.1
-DEFAULT_INPUT_DROPOUT_RATE = 0.1
+DROPOUT_RATE = 0.1
@dataclass
@@ -93,9 +90,7 @@ def __init__(self,
bias=True)
self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim)
- def forward(self, inputs, input_paddings, dropout_rate=None):
- if dropout_rate is None:
- dropout_rate = DEFAULT_INPUT_DROPOUT_RATE
+ def forward(self, inputs, input_paddings, dropout_rate):
output_paddings = input_paddings
outputs = inputs[:, None, :, :]
@@ -202,9 +197,7 @@ def __init__(self, config: ConformerConfig):
out_features=config.encoder_dim,
bias=True)
- def forward(self, inputs, padding_mask, dropout_rate=None):
- if dropout_rate is None:
- dropout_rate = DEFAULT_FFN_RESIDUAL_DROPOUT_RATE
+ def forward(self, inputs, padding_mask, dropout_rate):
inputs = self.ln(inputs)
inputs = self.linear1(inputs)
@@ -310,10 +303,7 @@ def __init__(self, config: ConformerConfig):
self.ln = LayerNorm(dim=config.encoder_dim)
self.self_attention = MHSAwithQS(config)
- def forward(self, outputs, paddings, dropout_rate=None):
- if dropout_rate is None:
- dropout_rate = DEFAULT_ATTN_RESIDUAL_DROPOUT_RATE
-
+ def forward(self, outputs, paddings, dropout_rate):
outputs = self.ln(outputs)
outputs = self.self_attention(
outputs,
@@ -400,10 +390,7 @@ def __init__(self, config):
self.bn = BatchNorm(config)
self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim)
- def forward(self, inputs, input_paddings, dropout_rate=None):
- if dropout_rate is None:
- dropout_rate = DEFAULT_CONV_RESIDUAL_DROPOUT_RATE
-
+ def forward(self, inputs, input_paddings, dropout_rate):
inputs = self.ln(inputs)
inputs = F.glu(torch.cat([self.lin1(inputs), self.lin2(inputs)], dim=2))
@@ -442,7 +429,7 @@ def __init__(self, config: ConformerConfig):
if config.use_post_layer_norm:
self.ln = LayerNorm(dim=config.encoder_dim)
- def forward(self, inputs, input_paddings, dropout_rate=None):
+ def forward(self, inputs, input_paddings, dropout_rate):
padding_mask = 1 - input_paddings[:, :, None]
inputs = inputs + 0.5 * self.ff1(inputs, padding_mask, dropout_rate)
inputs = inputs + self.mhsa(inputs, input_paddings, dropout_rate)
@@ -481,7 +468,7 @@ def __init__(self, config: ConformerConfig):
self.ln = LayerNorm(config.encoder_dim)
self.lin = nn.Linear(config.encoder_dim, config.vocab_size)
- def forward(self, inputs, input_paddings, dropout_rate=None):
+ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE):
outputs = inputs
output_paddings = input_paddings
outputs, output_paddings = self.preprocessor(outputs, output_paddings)
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
index 2d0942fe9..d99bc1608 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
@@ -24,6 +24,7 @@
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
MAX_INPUT_LENGTH = 320000
+DROPOUT_RATE = 0.1
class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload):
@@ -105,7 +106,8 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ update_batch_norm: bool,
+ dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
@@ -126,7 +128,8 @@ def model_fn(
with contexts[mode]():
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
logits, logits_paddings = model(inputs.to(DEVICE),
- input_paddings.to(DEVICE))
+ input_paddings.to(DEVICE),
+ dropout_rate=dropout_rate)
return (logits, logits_paddings), None
def _build_input_queue(
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py
index 21a4df614..a8480a343 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py
@@ -17,7 +17,7 @@
SpecAug
USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
-DEFAULT_DROPOUT_RATE = 0.1
+DROPOUT_RATE = 0.1
@dataclass
@@ -351,7 +351,7 @@ def __init__(self, config: DeepspeechConfig):
self.lin = nn.Linear(config.encoder_dim, config.vocab_size)
- def forward(self, inputs, input_paddings, dropout_rate=DEFAULT_DROPOUT_RATE):
+ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE):
outputs = inputs
output_paddings = input_paddings
diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py
index c8ed23dda..be5882333 100644
--- a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py
+++ b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py
@@ -11,7 +11,7 @@
from algoperf import init_utils
from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout
-DEFAULT_DROPOUT_RATE = 0.1
+DROPOUT_RATE = 0.1
def _make_mlp(in_dim, hidden_dims, activation_fn):
@@ -96,7 +96,7 @@ def __init__(self,
def forward(
self,
graph: GraphsTuple,
- dropout_rate: float = DEFAULT_DROPOUT_RATE) -> torch.Tensor:
+ dropout_rate: float = DROPOUT_RATE) -> torch.Tensor:
graph = graph._replace(
globals=torch.zeros([graph.n_node.shape[0], self.num_outputs],
diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
index 7ead696ce..281f4cd08 100644
--- a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
+++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
@@ -17,6 +17,8 @@
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
+DROPOUT_RATE = 0.1
+
def _pytorch_map(inputs: Any) -> Any:
if USE_PYTORCH_DDP:
@@ -166,7 +168,8 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ update_batch_norm: bool,
+ dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
"""Get predicted logits from the network for input graphs."""
del rng
del update_batch_norm # No BN in the GNN model.
@@ -186,7 +189,8 @@ def model_fn(
}
with contexts[mode]():
- logits = model(augmented_and_preprocessed_input_batch['inputs'])
+ logits = model(augmented_and_preprocessed_input_batch['inputs'],
+ dropout_rate=dropout_rate)
return logits, None
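
The same keyword call also works when the GNN is wrapped for distributed training, since DistributedDataParallel forwards positional and keyword arguments to the wrapped module. A small CPU-only sketch under that assumption:

    import torch
    from torch import nn
    import torch.nn.functional as F

    class TinyGNN(nn.Module):
        def forward(self, x, dropout_rate=0.1):
            return F.dropout(x, p=dropout_rate, training=self.training)

    model = TinyGNN()
    # DistributedDataParallel (and DataParallel) pass *args/**kwargs through to
    # the wrapped module, so this call is unchanged when the workload wraps the
    # model for multi-GPU training.
    logits = model(torch.randn(3, 4), dropout_rate=0.2)
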
diff --git a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py
index a5d822669..a43df30d4 100644
--- a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py
+++ b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py
@@ -9,7 +9,7 @@
from torch.nn.init import normal_
from torch.nn.init import xavier_uniform_
-DEFAULT_DROPOUT_RATE = 0.1
+DROPOUT_RATE = 0.1
def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor:
@@ -156,7 +156,7 @@ def forward(self,
inputs_segmentation: Optional[Tensor] = None,
targets_segmentation: Optional[Tensor] = None,
decode: bool = False,
- dropout_rate: float = DEFAULT_DROPOUT_RATE) -> Tensor:
+ dropout_rate: float = DROPOUT_RATE) -> Tensor:
"""
Args:
src: Tensor, shape [batch_size, seq_len]
diff --git a/algoperf/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py
index bb9c3834f..d30abc4c7 100644
--- a/algoperf/workloads/wmt/wmt_pytorch/workload.py
+++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py
@@ -22,6 +22,8 @@
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
+DROPOUT_RATE = 0.1
+
class WmtWorkload(BaseWmtWorkload):
"""WMT PyTorch workload."""
@@ -202,7 +204,8 @@ def model_fn(
model_state: spec.ModelAuxiliaryState,
mode: spec.ForwardPassMode,
rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ update_batch_norm: bool,
+ dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
@@ -228,7 +231,8 @@ def model_fn(
inputs_segmentation=augmented_and_preprocessed_input_batch.get(
'inputs_segmentation', None),
targets_segmentation=augmented_and_preprocessed_input_batch.get(
- 'targets_segmentation', None))
+ 'targets_segmentation', None),
+ dropout_rate=dropout_rate)
return logits_batch, None
From a7cba1a1acc9da53b9500cab8755f49c88bdbb2c Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Fri, 13 Jun 2025 14:19:36 +0200
Subject: [PATCH 056/123] remove aux_dropout from pytorch workloads
---
.../test_model_equivalence.py | 32 +------------------
.../test_model_equivalence.py | 2 +-
2 files changed, 2 insertions(+), 32 deletions(-)
diff --git a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
index a4238bbc9..4a1252a39 100644
--- a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
@@ -3,11 +3,7 @@
Run with:
python3 tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
-`dropout_rate` controls the following args:
-- `attention_residual_dropout_rate` (if None, 0.1
-- `conv_residual_dropout_rate` (if None, 0.0)
-- `feed_forward_residual_dropout_rate` (if None, 0.1)
-- `input_dropout_rate` (if None, 0.1)
+NOTE: default dropout_rate values are not tested here, since they differ between the original and custom models.
"""
from absl.testing import absltest, parameterized
@@ -85,31 +81,5 @@ def test_forward(self, dropout_rate):
assert_close(y1, y2, atol=0, rtol=0)
assert_close(p1, p2, atol=0, rtol=0)
- @parameterized.named_parameters(
- dict(testcase_name=''),
- )
- def test_default_dropout(self):
- """Test default dropout_rate."""
-
- torch.manual_seed(SEED)
- orig = OriginalModel(OriginalConfig(num_encoder_layers=N_LAYERS)).to(DEVICE)
- torch.manual_seed(SEED)
- cust = CustomModel(CustomConfig(num_encoder_layers=N_LAYERS)).to(DEVICE)
- orig.load_state_dict(cust.state_dict())
-
- x = torch.randn(B, T, device=DEVICE)
- paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
-
- torch.manual_seed(SEED)
- y1, p1 = orig(x, paddings)
- torch.manual_seed(SEED)
- y2, p2 = cust(x, paddings)
-
- assert_close(y1, y2, atol=0, rtol=0)
- assert_close(p1, p2, atol=0, rtol=0)
-
if __name__ == '__main__':
absltest.main()
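
The remaining parameterized tests follow the same seeded, exact-match recipe as the removed default-rate test. A condensed sketch of that recipe, with a hypothetical TinyNet standing in for the original/custom model pair:

    import torch
    from torch.testing import assert_close

    SEED = 0

    class TinyNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(8, 8)

        def forward(self, x, dropout_rate=0.1):
            x = self.lin(x)
            return torch.nn.functional.dropout(x, p=dropout_rate, training=self.training)

    torch.manual_seed(SEED)
    orig = TinyNet()
    torch.manual_seed(SEED)
    cust = TinyNet()
    cust.load_state_dict(orig.state_dict())  # identical weights

    x = torch.randn(4, 8)
    for mode in ('train', 'eval'):
        getattr(orig, mode)()
        getattr(cust, mode)()
        torch.manual_seed(SEED)  # same RNG stream, so identical dropout masks
        y1 = orig(x, dropout_rate=0.1)
        torch.manual_seed(SEED)
        y2 = cust(x, dropout_rate=0.1)
        assert_close(y1, y2, atol=0, rtol=0)
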
diff --git a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
index acdc8c5b3..58ddb354e 100644
--- a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
@@ -24,7 +24,7 @@
B, T = 32, 30_000
DEVICE = 'cuda'
-TORCH_COMPILE = True
+TORCH_COMPILE = False
os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"
torch.backends.cudnn.benchmark = False
From 05bff916dee7de6852afc6d95e2564ad57aa77ef Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Fri, 13 Jun 2025 20:45:13 +0000
Subject: [PATCH 057/123] fix to model_fn default dropout value
---
algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py | 2 +-
algoperf/workloads/fastmri/fastmri_jax/workload.py | 2 +-
algoperf/workloads/imagenet_vit/imagenet_jax/workload.py | 2 +-
.../workloads/librispeech_conformer/librispeech_jax/workload.py | 2 +-
.../librispeech_deepspeech/librispeech_jax/workload.py | 2 +-
algoperf/workloads/ogbg/ogbg_jax/workload.py | 2 +-
algoperf/workloads/wmt/wmt_jax/workload.py | 2 +-
7 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
index dcb7b9a57..cb7e8cf9f 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -115,7 +115,7 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
inputs = augmented_and_preprocessed_input_batch['inputs']
diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py
index bf0acfc8d..acdf077e1 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/workload.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py
@@ -52,7 +52,7 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
train = mode == spec.ForwardPassMode.TRAIN
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
index b8a870de5..08a8f4eb1 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -57,7 +57,7 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
train = mode == spec.ForwardPassMode.TRAIN
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
index 042dba7f4..8d966ef87 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
@@ -109,7 +109,7 @@ def model_fn(
rng: spec.RandomState,
update_batch_norm: bool,
use_running_average_bn: Optional[bool] = None,
- dropout_rate: Optional[float] = None,
+ dropout_rate: Optional[float] = models.DROPOUT_RATE,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
variables = {'params': params, **model_state}
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
index 2213f189e..2bb119439 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
@@ -57,7 +57,7 @@ def model_fn(
rng: spec.RandomState,
update_batch_norm: bool,
use_running_average_bn: Optional[bool] = None,
- dropout_rate: Optional[bool] = None
+      dropout_rate: Optional[float] = models.DROPOUT_RATE,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
variables = {'params': params, **model_state}
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py
index aaa5b4064..e03252ed9 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/workload.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py
@@ -53,7 +53,7 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: Optional[float]) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
"""Get predicted logits from the network for input graphs."""
del update_batch_norm # No BN in the GNN model.
if model_state is not None:
diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py
index 9e109dc86..9548f5b7e 100644
--- a/algoperf/workloads/wmt/wmt_jax/workload.py
+++ b/algoperf/workloads/wmt/wmt_jax/workload.py
@@ -250,7 +250,7 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: Optional[float] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+      dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
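
On the JAX side, the same module-level default flows from the model_fn signature into the Flax model call. A minimal sketch of the pattern, using the stock flax.linen.Dropout purely for illustration and simplified names; the actual workloads' models and plumbing are more involved:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    DROPOUT_RATE = 0.1  # mirrors the models.DROPOUT_RATE constants referenced above

    class TinyModel(nn.Module):
        @nn.compact
        def __call__(self, x, train=False, dropout_rate=DROPOUT_RATE):
            x = nn.Dense(4)(x)
            # Stock nn.Dropout used here; the rate is supplied per call.
            return nn.Dropout(rate=dropout_rate, deterministic=not train)(x)

    def model_fn(params, x, rng, dropout_rate=DROPOUT_RATE):
        # The workload-level default lives in the model_fn signature, as in the
        # patches above; callers can still override it per step.
        return TinyModel().apply(params, x, train=True, dropout_rate=dropout_rate,
                                 rngs={'dropout': rng})

    x = jnp.ones((2, 3))
    params = TinyModel().init(jax.random.PRNGKey(0), x)
    out = model_fn(params, x, jax.random.PRNGKey(1), dropout_rate=0.2)
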
From d8e39b0da371abbcf311ce1d09e06439bd5a0eec Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Sun, 15 Jun 2025 17:08:50 +0200
Subject: [PATCH 058/123] fix to model_fn default dropout_rate
---
.../criteo1tb/criteo1tb_pytorch/workload.py | 3 +--
.../fastmri/fastmri_pytorch/workload.py | 5 ++---
.../imagenet_vit/imagenet_pytorch/workload.py | 3 +--
.../librispeech_pytorch/workload.py | 3 +--
.../librispeech_pytorch/workload.py | 17 ++++++++++++++++-
.../workloads/ogbg/ogbg_pytorch/workload.py | 5 ++---
algoperf/workloads/wmt/wmt_pytorch/workload.py | 5 ++---
7 files changed, 25 insertions(+), 16 deletions(-)
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
index 74cb3e140..48c6592f2 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
@@ -15,7 +15,6 @@
BaseCriteo1TbDlrmSmallWorkload
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
-DROPOUT_RATE = 0.0
class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload):
@@ -105,7 +104,7 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
index 6374c62d6..9b96230fc 100644
--- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
@@ -13,14 +13,13 @@
from algoperf import pytorch_utils
from algoperf import spec
import algoperf.random_utils as prng
+from algoperf.workloads.fastmri.fastmri_pytorch import models
from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet
from algoperf.workloads.fastmri.fastmri_pytorch.ssim import ssim
from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
-DROPOUT_RATE = 0.0
-
class FastMRIWorkload(BaseFastMRIWorkload):
@@ -137,7 +136,7 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
index 8b011071a..f86a1b1c2 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
@@ -16,7 +16,6 @@
from algoperf.workloads.imagenet_vit.workload import decode_variant
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
-DROPOUT_RATE = 0.0
# Make sure we inherit from the ViT base workload first.
@@ -53,7 +52,7 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
index d99bc1608..0477a7389 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
@@ -24,7 +24,6 @@
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
MAX_INPUT_LENGTH = 320000
-DROPOUT_RATE = 0.1
class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload):
@@ -107,7 +106,7 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
index e6ec4764f..bf345cfc9 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Dict, Tuple
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -6,6 +6,7 @@
from algoperf import param_utils
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch import models
from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
initialize
from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \
@@ -54,6 +55,20 @@ def init_model_fn(
else:
model = torch.nn.DataParallel(model)
return model, None
+
+ def model_fn(
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+    # Override the parent model_fn only to change the default dropout_rate.
+ return super().model_fn(
+ params, augmented_and_preprocessed_input_batch, model_state,
+ mode, rng, update_batch_norm, dropout_rate)
def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['lin.weight', 'lin.bias']
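
The deepspeech workload above overrides model_fn only to swap the default rate and then delegates to the parent class. A generic sketch of that delegation pattern, with hypothetical class names:

    class BaseWorkload:
        def model_fn(self, batch, dropout_rate=0.1):
            # Shared forward-pass logic lives here.
            return f'forward(batch={batch}, p={dropout_rate})'

    class DeepspeechLikeWorkload(BaseWorkload):
        def model_fn(self, batch, dropout_rate=0.3):
            # Same behaviour, different default dropout_rate.
            return super().model_fn(batch, dropout_rate)

    print(BaseWorkload().model_fn('x'))            # p=0.1
    print(DeepspeechLikeWorkload().model_fn('x'))  # p=0.3
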
diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
index 281f4cd08..758b36b60 100644
--- a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
+++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
@@ -12,13 +12,12 @@
from algoperf import pytorch_utils
from algoperf import spec
from algoperf.workloads.ogbg import metrics
+from algoperf.workloads.ogbg.ogbg_pytorch import models
from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN
from algoperf.workloads.ogbg.workload import BaseOgbgWorkload
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
-DROPOUT_RATE = 0.1
-
def _pytorch_map(inputs: Any) -> Any:
if USE_PYTORCH_DDP:
@@ -169,7 +168,7 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
"""Get predicted logits from the network for input graphs."""
del rng
del update_batch_norm # No BN in the GNN model.
diff --git a/algoperf/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py
index d30abc4c7..4c787becc 100644
--- a/algoperf/workloads/wmt/wmt_pytorch/workload.py
+++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py
@@ -17,13 +17,12 @@
from algoperf import spec
from algoperf.workloads.wmt import bleu
from algoperf.workloads.wmt.wmt_pytorch import decode
+from algoperf.workloads.wmt.wmt_pytorch import models
from algoperf.workloads.wmt.wmt_pytorch.models import Transformer
from algoperf.workloads.wmt.workload import BaseWmtWorkload
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
-DROPOUT_RATE = 0.1
-
class WmtWorkload(BaseWmtWorkload):
"""WMT PyTorch workload."""
@@ -205,7 +204,7 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
From 7a0015830e840211b8002f068b1eb918c8390f5c Mon Sep 17 00:00:00 2001
From: Niccolo-Ajroldi
Date: Sun, 15 Jun 2025 17:14:18 +0200
Subject: [PATCH 059/123] rm models_dropout torch files
---
.../criteo1tb/criteo1tb_pytorch/models.py | 54 +-
.../criteo1tb_pytorch/models_dropout.py | 297 ------
.../fastmri/fastmri_pytorch/models.py | 44 +-
.../fastmri/fastmri_pytorch/models_dropout.py | 173 ---
.../imagenet_vit/imagenet_pytorch/models.py | 84 +-
.../imagenet_pytorch/models_dropout.py | 378 -------
.../librispeech_pytorch/models.py | 70 +-
.../librispeech_pytorch/models_dropout.py | 482 ---------
.../librispeech_pytorch/models.py | 32 +-
.../librispeech_pytorch/models_dropout.py | 379 -------
.../workloads/ogbg/ogbg_pytorch/models.py | 35 +-
.../ogbg/ogbg_pytorch/models_dropout.py | 315 ------
algoperf/workloads/wmt/wmt_pytorch/models.py | 181 ++--
.../wmt/wmt_pytorch/models_dropout.py | 989 ------------------
14 files changed, 234 insertions(+), 3279 deletions(-)
delete mode 100644 algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
delete mode 100644 algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
delete mode 100644 algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py
delete mode 100644 algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py
delete mode 100644 algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py
delete mode 100644 algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py
delete mode 100644 algoperf/workloads/wmt/wmt_pytorch/models_dropout.py
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py
index 7a40f0e81..f0653a665 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py
@@ -5,20 +5,32 @@
import torch
from torch import nn
+from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout
+
+DROPOUT_RATE = 0.0
+
class DenseBlock(nn.Module):
"""Dense block with optional residual connection.""" ""
-
def __init__(self, module, resnet=False):
super().__init__()
self.module = module
self.resnet = resnet
def forward(self, x):
- if self.resnet:
- return self.module(x) + x
- else:
- return self.module(x)
+ return self.module(x) + x if self.resnet else self.module(x)
+
+
+class DenseBlockWithDropout(nn.Module):
+ """Dense block with optional residual connection and support for dropout."""
+ def __init__(self, module, resnet=False):
+ super().__init__()
+ self.module = module
+ self.resnet = resnet
+ self._supports_custom_dropout = True
+
+ def forward(self, x, p):
+ return self.module(x, p) + x if self.resnet else self.module(x, p)
class DotInteract(nn.Module):
@@ -58,7 +70,6 @@ def __init__(self,
mlp_bottom_dims=(256, 256, 256),
mlp_top_dims=(256, 256, 256, 256, 1),
embed_dim=128,
- dropout_rate=0.0,
use_layer_norm=False,
embedding_init_multiplier=None):
super().__init__()
@@ -116,17 +127,16 @@ def __init__(self,
block.append(nn.Linear(fan_in, fan_out))
if layer_idx < (num_layers_top - 1):
block.append(nn.ReLU(inplace=True))
- if (dropout_rate is not None and dropout_rate > 0.0 and
- layer_idx == num_layers_top - 2):
- block.append(nn.Dropout(p=dropout_rate))
- block = nn.Sequential(*block)
+ if layer_idx == num_layers_top - 2:
+ block.append(CustomDropout())
+ block = SequentialWithDropout(*block)
if (layer_idx != 0) and (layer_idx != num_layers_top - 1):
- block = DenseBlock(block, resnet=True)
+ block = DenseBlockWithDropout(block, resnet=True)
else:
- block = DenseBlock(block)
+ block = DenseBlockWithDropout(block)
mlp_top_blocks.append(block)
fan_in = fan_out
- self.top_mlp = nn.Sequential(*mlp_top_blocks)
+ self.top_mlp = SequentialWithDropout(*mlp_top_blocks)
for module in self.top_mlp.modules():
if isinstance(module, nn.Linear):
@@ -138,7 +148,8 @@ def __init__(self,
0.,
math.sqrt(1. / module.out_features))
- def forward(self, x):
+ def forward(self, x, dropout_rate=DROPOUT_RATE):
+
batch_size = x.shape[0]
dense_features, sparse_features = torch.split(
@@ -157,7 +168,7 @@ def forward(self, x):
top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1)
# Final MLP.
- logits = self.top_mlp(top_mlp_input)
+ logits = self.top_mlp(top_mlp_input, dropout_rate)
return logits
@@ -179,7 +190,6 @@ def __init__(self,
mlp_bottom_dims=(512, 256, 128),
mlp_top_dims=(1024, 1024, 512, 256, 1),
embed_dim=128,
- dropout_rate=0.0,
use_layer_norm=False,
embedding_init_multiplier=None):
super().__init__()
@@ -242,10 +252,9 @@ def __init__(self,
top_mlp_layers.append(nn.ReLU(inplace=True))
if use_layer_norm:
top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6))
- if (dropout_rate is not None and dropout_rate > 0.0 and
- layer_idx == num_layers_top - 2):
- top_mlp_layers.append(nn.Dropout(p=dropout_rate))
- self.top_mlp = nn.Sequential(*top_mlp_layers)
+ if layer_idx == num_layers_top - 2:
+ top_mlp_layers.append(CustomDropout())
+ self.top_mlp = SequentialWithDropout(*top_mlp_layers)
if use_layer_norm:
self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6)
else:
@@ -260,7 +269,8 @@ def __init__(self,
0.,
math.sqrt(1. / module.out_features))
- def forward(self, x):
+ def forward(self, x, dropout_rate=DROPOUT_RATE):
+
batch_size = x.shape[0]
dense_features, sparse_features = torch.split(
@@ -283,5 +293,5 @@ def forward(self, x):
dense_features=embedded_dense, sparse_features=embedded_sparse)
# Final MLP.
- logits = self.top_mlp(concatenated_dense)
+ logits = self.top_mlp(concatenated_dense, dropout_rate)
return logits
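
DLRMResNet and DlrmSmall now route the rate through SequentialWithDropout and CustomDropout from algoperf.pytorch_utils, whose implementations are not part of this diff. A rough sketch of how such helpers could behave, stated as an assumption rather than the actual implementation:

    import torch
    import torch.nn.functional as F
    from torch import nn

    class CustomDropout(nn.Module):
        """Dropout whose rate is supplied at call time instead of construction."""
        def __init__(self):
            super().__init__()
            self._supports_custom_dropout = True

        def forward(self, x, p):
            return F.dropout(x, p=p, training=self.training)

    class SequentialWithDropout(nn.Sequential):
        """Sequential that forwards the dropout rate only to modules that accept it."""
        def forward(self, x, p):
            for module in self:
                if getattr(module, '_supports_custom_dropout', False):
                    x = module(x, p)
                else:
                    x = module(x)
            return x

    block = SequentialWithDropout(nn.Linear(8, 8), nn.ReLU(), CustomDropout())
    out = block(torch.randn(2, 8), 0.1)

The _supports_custom_dropout flag matches the marker set on DenseBlockWithDropout above, so container modules know which children take the extra rate argument.
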
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
deleted file mode 100644
index f0653a665..000000000
--- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models_dropout.py
+++ /dev/null
@@ -1,297 +0,0 @@
-"""Pytorch implementation of DLRM-Small."""
-
-import math
-
-import torch
-from torch import nn
-
-from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout
-
-DROPOUT_RATE = 0.0
-
-
-class DenseBlock(nn.Module):
- """Dense block with optional residual connection.""" ""
- def __init__(self, module, resnet=False):
- super().__init__()
- self.module = module
- self.resnet = resnet
-
- def forward(self, x):
- return self.module(x) + x if self.resnet else self.module(x)
-
-
-class DenseBlockWithDropout(nn.Module):
- """Dense block with optional residual connection and support for dropout."""
- def __init__(self, module, resnet=False):
- super().__init__()
- self.module = module
- self.resnet = resnet
- self._supports_custom_dropout = True
-
- def forward(self, x, p):
- return self.module(x, p) + x if self.resnet else self.module(x, p)
-
-
-class DotInteract(nn.Module):
- """Performs feature interaction operation between dense or sparse features."""
-
- def __init__(self, num_sparse_features):
- super().__init__()
- self.triu_indices = torch.triu_indices(num_sparse_features + 1,
- num_sparse_features + 1)
-
- def forward(self, dense_features, sparse_features):
- combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features),
- dim=1)
- interactions = torch.bmm(combined_values,
- torch.transpose(combined_values, 1, 2))
- interactions_flat = interactions[:,
- self.triu_indices[0],
- self.triu_indices[1]]
- return torch.cat((dense_features, interactions_flat), dim=1)
-
-
-class DLRMResNet(nn.Module):
- """Define a DLRM-Small model.
-
- Parameters:
- vocab_size: vocab size of embedding table.
- num_dense_features: number of dense features as the bottom mlp input.
- mlp_bottom_dims: dimensions of dense layers of the bottom mlp.
- mlp_top_dims: dimensions of dense layers of the top mlp.
- embed_dim: embedding dimension.
- """
-
- def __init__(self,
- vocab_size,
- num_dense_features=13,
- num_sparse_features=26,
- mlp_bottom_dims=(256, 256, 256),
- mlp_top_dims=(256, 256, 256, 256, 1),
- embed_dim=128,
- use_layer_norm=False,
- embedding_init_multiplier=None):
- super().__init__()
- self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32)
- self.num_dense_features = num_dense_features
- self.num_sparse_features = num_sparse_features
- self.mlp_bottom_dims = mlp_bottom_dims
- self.mlp_top_dims = mlp_top_dims
- self.embed_dim = embed_dim
-
- # Ideally, we should use the pooled embedding implementation from
- # `TorchRec`. However, in order to have identical implementation
- # with that of Jax, we define a single embedding matrix.
- num_chunks = 4
- assert vocab_size % num_chunks == 0
- self.embedding_table_chucks = []
- scale = 1.0 / torch.sqrt(self.vocab_size)
- for i in range(num_chunks):
- chunk = nn.Parameter(
- torch.Tensor(self.vocab_size // num_chunks, self.embed_dim))
- chunk.data.uniform_(0, 1)
- chunk.data = scale * chunk.data
- self.register_parameter(f'embedding_chunk_{i}', chunk)
- self.embedding_table_chucks.append(chunk)
-
- input_dim = self.num_dense_features
- bot_mlp_blocks = []
- for layer_idx, dense_dim in enumerate(self.mlp_bottom_dims):
- block = []
- block.append(nn.Linear(input_dim, dense_dim))
- block.append(nn.ReLU(inplace=True))
- block = nn.Sequential(*block)
- if layer_idx > 0:
- block = DenseBlock(block, resnet=True)
- else:
- block = DenseBlock(block)
- bot_mlp_blocks.append(block)
- input_dim = dense_dim
- self.bot_mlp = nn.Sequential(*bot_mlp_blocks)
-
- for module in self.bot_mlp.modules():
- if isinstance(module, nn.Linear):
- limit = math.sqrt(6. / (module.in_features + module.out_features))
- nn.init.uniform_(module.weight.data, -limit, limit)
- nn.init.normal_(module.bias.data,
- 0.,
- math.sqrt(1. / module.out_features))
-
- # Number of sparse features = 26
- fan_in = (26 * self.embed_dim) + self.mlp_bottom_dims[-1]
- num_layers_top = len(self.mlp_top_dims)
- mlp_top_blocks = []
- for layer_idx, fan_out in enumerate(self.mlp_top_dims):
- block = []
- block.append(nn.Linear(fan_in, fan_out))
- if layer_idx < (num_layers_top - 1):
- block.append(nn.ReLU(inplace=True))
- if layer_idx == num_layers_top - 2:
- block.append(CustomDropout())
- block = SequentialWithDropout(*block)
- if (layer_idx != 0) and (layer_idx != num_layers_top - 1):
- block = DenseBlockWithDropout(block, resnet=True)
- else:
- block = DenseBlockWithDropout(block)
- mlp_top_blocks.append(block)
- fan_in = fan_out
- self.top_mlp = SequentialWithDropout(*mlp_top_blocks)
-
- for module in self.top_mlp.modules():
- if isinstance(module, nn.Linear):
- nn.init.normal_(
- module.weight.data,
- 0.,
- math.sqrt(2. / (module.in_features + module.out_features)))
- nn.init.normal_(module.bias.data,
- 0.,
- math.sqrt(1. / module.out_features))
-
- def forward(self, x, dropout_rate=DROPOUT_RATE):
-
- batch_size = x.shape[0]
-
- dense_features, sparse_features = torch.split(
- x, [self.num_dense_features, self.num_sparse_features], 1)
-
- # Bottom MLP.
- embedded_dense = self.bot_mlp(dense_features)
-
- # Sparse feature processing.
- sparse_features = sparse_features.to(dtype=torch.int32)
- idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size
- embedding_table = torch.cat(self.embedding_table_chucks, dim=0)
- embedded_sparse = embedding_table[idx_lookup]
- embedded_sparse = torch.reshape(embedded_sparse,
- [batch_size, 26 * self.embed_dim])
- top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1)
-
- # Final MLP.
- logits = self.top_mlp(top_mlp_input, dropout_rate)
- return logits
-
-
-class DlrmSmall(nn.Module):
- """Define a DLRM-Small model.
-
- Parameters:
- vocab_size: vocab size of embedding table.
- num_dense_features: number of dense features as the bottom mlp input.
- mlp_bottom_dims: dimensions of dense layers of the bottom mlp.
- mlp_top_dims: dimensions of dense layers of the top mlp.
- embed_dim: embedding dimension.
- """
-
- def __init__(self,
- vocab_size,
- num_dense_features=13,
- num_sparse_features=26,
- mlp_bottom_dims=(512, 256, 128),
- mlp_top_dims=(1024, 1024, 512, 256, 1),
- embed_dim=128,
- use_layer_norm=False,
- embedding_init_multiplier=None):
- super().__init__()
- self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32)
- self.num_dense_features = num_dense_features
- self.num_sparse_features = num_sparse_features
- self.mlp_bottom_dims = mlp_bottom_dims
- self.mlp_top_dims = mlp_top_dims
- self.embed_dim = embed_dim
- self.embedding_init_multiplier = embedding_init_multiplier
-
- # Ideally, we should use the pooled embedding implementation from
- # `TorchRec`. However, in order to have identical implementation
- # with that of Jax, we define a single embedding matrix.
- num_chunks = 4
- assert vocab_size % num_chunks == 0
- self.embedding_table_chucks = []
-
- if self.embedding_init_multiplier is None:
- scale = 1.0 / torch.sqrt(self.vocab_size)
- else:
- scale = self.embedding_init_multiplier
-
- for i in range(num_chunks):
- chunk = nn.Parameter(
- torch.Tensor(self.vocab_size // num_chunks, self.embed_dim))
- chunk.data.uniform_(0, 1)
- chunk.data = scale * chunk.data
- self.register_parameter(f'embedding_chunk_{i}', chunk)
- self.embedding_table_chucks.append(chunk)
-
- input_dim = self.num_dense_features
- bottom_mlp_layers = []
- for dense_dim in self.mlp_bottom_dims:
- bottom_mlp_layers.append(nn.Linear(input_dim, dense_dim))
- bottom_mlp_layers.append(nn.ReLU(inplace=True))
- if use_layer_norm:
- bottom_mlp_layers.append(nn.LayerNorm(dense_dim, eps=1e-6))
- input_dim = dense_dim
- self.bot_mlp = nn.Sequential(*bottom_mlp_layers)
- for module in self.bot_mlp.modules():
- if isinstance(module, nn.Linear):
- limit = math.sqrt(6. / (module.in_features + module.out_features))
- nn.init.uniform_(module.weight.data, -limit, limit)
- nn.init.normal_(module.bias.data,
- 0.,
- math.sqrt(1. / module.out_features))
-
- self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,)
-
- # TODO: Write down the formula here instead of the constant.
- input_dims = 506
- num_layers_top = len(self.mlp_top_dims)
- top_mlp_layers = []
- for layer_idx, fan_out in enumerate(self.mlp_top_dims):
- fan_in = input_dims if layer_idx == 0 \
- else self.mlp_top_dims[layer_idx - 1]
- top_mlp_layers.append(nn.Linear(fan_in, fan_out))
- if layer_idx < (num_layers_top - 1):
- top_mlp_layers.append(nn.ReLU(inplace=True))
- if use_layer_norm:
- top_mlp_layers.append(nn.LayerNorm(fan_out, eps=1e-6))
- if layer_idx == num_layers_top - 2:
- top_mlp_layers.append(CustomDropout())
- self.top_mlp = SequentialWithDropout(*top_mlp_layers)
- if use_layer_norm:
- self.embed_ln = nn.LayerNorm(self.embed_dim, eps=1e-6)
- else:
- self.embed_ln = None
- for module in self.top_mlp.modules():
- if isinstance(module, nn.Linear):
- nn.init.normal_(
- module.weight.data,
- 0.,
- math.sqrt(2. / (module.in_features + module.out_features)))
- nn.init.normal_(module.bias.data,
- 0.,
- math.sqrt(1. / module.out_features))
-
- def forward(self, x, dropout_rate=DROPOUT_RATE):
-
- batch_size = x.shape[0]
-
- dense_features, sparse_features = torch.split(
- x, [self.num_dense_features, self.num_sparse_features], 1)
-
- # Bottom MLP.
- embedded_dense = self.bot_mlp(dense_features)
-
- # Sparse feature processing.
- sparse_features = sparse_features.to(dtype=torch.int32)
- idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size
- embedding_table = torch.cat(self.embedding_table_chucks, dim=0)
- embedded_sparse = embedding_table[idx_lookup]
- embedded_sparse = torch.reshape(embedded_sparse,
- [batch_size, -1, self.embed_dim])
- if self.embed_ln:
- embedded_sparse = self.embed_ln(embedded_sparse)
- # Dot product interactions.
- concatenated_dense = self.dot_interact(
- dense_features=embedded_dense, sparse_features=embedded_sparse)
-
- # Final MLP.
- logits = self.top_mlp(concatenated_dense, dropout_rate)
- return logits
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models.py b/algoperf/workloads/fastmri/fastmri_pytorch/models.py
index 28f20bf20..0b8ac5499 100644
--- a/algoperf/workloads/fastmri/fastmri_pytorch/models.py
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/models.py
@@ -13,6 +13,9 @@
from torch.nn import functional as F
from algoperf import init_utils
+from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout
+
+DROPOUT_RATE = 0.0
class UNet(nn.Module):
@@ -27,7 +30,6 @@ def __init__(self,
out_chans: int = 1,
num_channels: int = 32,
num_pool_layers: int = 4,
- dropout_rate: Optional[float] = 0.0,
use_tanh: bool = False,
use_layer_norm: bool = False) -> None:
super().__init__()
@@ -36,21 +38,19 @@ def __init__(self,
self.out_chans = out_chans
self.num_channels = num_channels
self.num_pool_layers = num_pool_layers
- if dropout_rate is None:
- dropout_rate = 0.0
+
self.down_sample_layers = nn.ModuleList([
ConvBlock(in_chans,
num_channels,
- dropout_rate,
use_tanh,
use_layer_norm)
])
ch = num_channels
for _ in range(num_pool_layers - 1):
self.down_sample_layers.append(
- ConvBlock(ch, ch * 2, dropout_rate, use_tanh, use_layer_norm))
+ ConvBlock(ch, ch * 2, use_tanh, use_layer_norm))
ch *= 2
- self.conv = ConvBlock(ch, ch * 2, dropout_rate, use_tanh, use_layer_norm)
+ self.conv = ConvBlock(ch, ch * 2, use_tanh, use_layer_norm)
self.up_conv = nn.ModuleList()
self.up_transpose_conv = nn.ModuleList()
@@ -59,14 +59,14 @@ def __init__(self,
self.up_transpose_conv.append(
TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
self.up_conv.append(
- ConvBlock(ch * 2, ch, dropout_rate, use_tanh, use_layer_norm))
+ ConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
ch //= 2
self.up_transpose_conv.append(
TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
self.up_conv.append(
- nn.Sequential(
- ConvBlock(ch * 2, ch, dropout_rate, use_tanh, use_layer_norm),
+ SequentialWithDropout(
+ ConvBlock(ch * 2, ch, use_tanh, use_layer_norm),
nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1),
))
@@ -74,24 +74,28 @@ def __init__(self,
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
init_utils.pytorch_default_init(m)
- def forward(self, x: Tensor) -> Tensor:
+ def forward(
+ self,
+ x: Tensor,
+ dropout_rate: float = DROPOUT_RATE) -> Tensor:
+
stack = []
output = x
# apply down-sampling layers
for layer in self.down_sample_layers:
- output = layer(output)
+ output = layer(output, dropout_rate)
stack.append(output)
output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)
- output = self.conv(output)
+ output = self.conv(output, dropout_rate)
# apply up-sampling layers
for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv):
downsample_layer = stack.pop()
output = transpose_conv(output)
- # reflect pad on the right/botton if needed to handle
+ # reflect pad on the right/bottom if needed to handle
# odd input dimensions
padding = [0, 0, 0, 0]
if output.shape[-1] != downsample_layer.shape[-1]:
@@ -102,7 +106,7 @@ def forward(self, x: Tensor) -> Tensor:
output = F.pad(output, padding, "reflect")
output = torch.cat([output, downsample_layer], dim=1)
- output = conv(output)
+ output = conv(output, dropout_rate)
return output
@@ -114,10 +118,10 @@ class ConvBlock(nn.Module):
def __init__(self,
in_chans: int,
out_chans: int,
- dropout_rate: float,
use_tanh: bool,
use_layer_norm: bool) -> None:
super().__init__()
+ self._supports_custom_dropout = True
if use_layer_norm:
norm_layer = partial(nn.GroupNorm, 1, eps=1e-6)
@@ -127,19 +131,19 @@ def __init__(self,
activation_fn = nn.Tanh()
else:
activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True)
- self.conv_layers = nn.Sequential(
+ self.conv_layers = SequentialWithDropout(
nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
norm_layer(out_chans),
activation_fn,
- nn.Dropout2d(dropout_rate),
+ CustomDropout2d(),
nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
norm_layer(out_chans),
activation_fn,
- nn.Dropout2d(dropout_rate),
+ CustomDropout2d(),
)
- def forward(self, x: Tensor) -> Tensor:
- return self.conv_layers(x)
+ def forward(self, x: Tensor, dropout_rate: float) -> Tensor:
+ return self.conv_layers(x, dropout_rate)
class TransposeConvBlock(nn.Module):
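
With the rate removed from the constructor, the UNet now takes it at call time. A small usage-level sketch assuming the constructor defaults shown above; the input shape is illustrative only:

    import torch
    from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet

    unet = UNet(in_chans=1, out_chans=1, num_channels=32, num_pool_layers=4)
    x = torch.randn(2, 1, 64, 64)  # (batch, channels, H, W); shape illustrative
    out = unet(x, dropout_rate=0.1)  # rate reaches every ConvBlock's CustomDropout2d
    print(out.shape)
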
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py b/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
deleted file mode 100644
index 0b8ac5499..000000000
--- a/algoperf/workloads/fastmri/fastmri_pytorch/models_dropout.py
+++ /dev/null
@@ -1,173 +0,0 @@
-"""U-Net Model.
-
-Adapted from fastMRI:
-https://github.com/facebookresearch/fastMRI/blob/main/fastmri/models/unet.py
-"""
-
-from functools import partial
-from typing import Optional
-
-import torch
-from torch import nn
-from torch import Tensor
-from torch.nn import functional as F
-
-from algoperf import init_utils
-from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout
-
-DROPOUT_RATE = 0.0
-
-
-class UNet(nn.Module):
- r"""U-Net model from
- `"U-net: Convolutional networks
- for biomedical image segmentation"
- `_.
- """
-
- def __init__(self,
- in_chans: int = 1,
- out_chans: int = 1,
- num_channels: int = 32,
- num_pool_layers: int = 4,
- use_tanh: bool = False,
- use_layer_norm: bool = False) -> None:
- super().__init__()
-
- self.in_chans = in_chans
- self.out_chans = out_chans
- self.num_channels = num_channels
- self.num_pool_layers = num_pool_layers
-
- self.down_sample_layers = nn.ModuleList([
- ConvBlock(in_chans,
- num_channels,
- use_tanh,
- use_layer_norm)
- ])
- ch = num_channels
- for _ in range(num_pool_layers - 1):
- self.down_sample_layers.append(
- ConvBlock(ch, ch * 2, use_tanh, use_layer_norm))
- ch *= 2
- self.conv = ConvBlock(ch, ch * 2, use_tanh, use_layer_norm)
-
- self.up_conv = nn.ModuleList()
- self.up_transpose_conv = nn.ModuleList()
-
- for _ in range(num_pool_layers - 1):
- self.up_transpose_conv.append(
- TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
- self.up_conv.append(
- ConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
- ch //= 2
-
- self.up_transpose_conv.append(
- TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
- self.up_conv.append(
- SequentialWithDropout(
- ConvBlock(ch * 2, ch, use_tanh, use_layer_norm),
- nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1),
- ))
-
- for m in self.modules():
- if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
- init_utils.pytorch_default_init(m)
-
- def forward(
- self,
- x: Tensor,
- dropout_rate: float = DROPOUT_RATE) -> Tensor:
-
- stack = []
- output = x
-
- # apply down-sampling layers
- for layer in self.down_sample_layers:
- output = layer(output, dropout_rate)
- stack.append(output)
- output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0)
-
- output = self.conv(output, dropout_rate)
-
- # apply up-sampling layers
- for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv):
- downsample_layer = stack.pop()
- output = transpose_conv(output)
-
- # reflect pad on the right/bottom if needed to handle
- # odd input dimensions
- padding = [0, 0, 0, 0]
- if output.shape[-1] != downsample_layer.shape[-1]:
- padding[1] = 1 # padding right
- if output.shape[-2] != downsample_layer.shape[-2]:
- padding[3] = 1 # padding bottom
- if torch.sum(torch.tensor(padding)) != 0:
- output = F.pad(output, padding, "reflect")
-
- output = torch.cat([output, downsample_layer], dim=1)
- output = conv(output, dropout_rate)
-
- return output
-
-
-class ConvBlock(nn.Module):
- # A Convolutional Block that consists of two convolution layers each
- # followed by instance normalization, LeakyReLU activation and dropout_rate.
-
- def __init__(self,
- in_chans: int,
- out_chans: int,
- use_tanh: bool,
- use_layer_norm: bool) -> None:
- super().__init__()
- self._supports_custom_dropout = True
-
- if use_layer_norm:
- norm_layer = partial(nn.GroupNorm, 1, eps=1e-6)
- else:
- norm_layer = nn.InstanceNorm2d
- if use_tanh:
- activation_fn = nn.Tanh()
- else:
- activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True)
- self.conv_layers = SequentialWithDropout(
- nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
- norm_layer(out_chans),
- activation_fn,
- CustomDropout2d(),
- nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
- norm_layer(out_chans),
- activation_fn,
- CustomDropout2d(),
- )
-
- def forward(self, x: Tensor, dropout_rate: float) -> Tensor:
- return self.conv_layers(x, dropout_rate)
-
-
-class TransposeConvBlock(nn.Module):
- # A Transpose Convolutional Block that consists of one convolution transpose
- # layers followed by instance normalization and LeakyReLU activation.
-
- def __init__(
- self,
- in_chans: int,
- out_chans: int,
- use_tanh: bool,
- use_layer_norm: bool,
- ):
- super().__init__()
- if use_tanh:
- activation_fn = nn.Tanh()
- else:
- activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True)
- self.layers = nn.Sequential(
- nn.ConvTranspose2d(
- in_chans, out_chans, kernel_size=2, stride=2, bias=False),
- nn.InstanceNorm2d(out_chans),
- activation_fn,
- )
-
- def forward(self, x: Tensor) -> Tensor:
- return self.layers(x)
diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py
index fcf0992d3..60e09edb5 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py
@@ -14,7 +14,9 @@
from algoperf import init_utils
from algoperf import spec
-from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention
+from algoperf.workloads.wmt.wmt_pytorch.models_dropout import MultiheadAttention
+
+DROPOUT_RATE = 0.0
def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor:
@@ -41,18 +43,15 @@ def __init__(
self,
width: int,
mlp_dim: Optional[int] = None, # Defaults to 4x input dim.
- use_glu: bool = False,
- dropout_rate: float = 0.0) -> None:
+ use_glu: bool = False) -> None:
super().__init__()
self.width = width
self.mlp_dim = mlp_dim or 4 * width
self.use_glu = use_glu
- self.dropout_rate = dropout_rate
self.linear1 = nn.Linear(self.width, self.mlp_dim)
self.act_fnc = nn.GELU(approximate='tanh')
- self.dropout = nn.Dropout(self.dropout_rate)
if self.use_glu:
self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim)
@@ -70,7 +69,8 @@ def reset_parameters(self) -> None:
if module.bias is not None:
module.bias.data.normal_(std=1e-6)
- def forward(self, x: spec.Tensor) -> spec.Tensor:
+ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
+
x = self.linear1(x)
x = self.act_fnc(x)
@@ -78,7 +78,7 @@ def forward(self, x: spec.Tensor) -> spec.Tensor:
y = self.glu_linear(x)
x = x * y
- x = self.dropout(x)
+ x = F.dropout(x, dropout_rate, training=self.training)
x = self.linear2(x)
return x
@@ -88,8 +88,7 @@ class SelfAttention(nn.Module):
def __init__(self,
width: int,
- num_heads: int = 8,
- dropout_rate: float = 0.0) -> None:
+ num_heads: int = 8) -> None:
super().__init__()
self.width = width
@@ -104,7 +103,6 @@ def __init__(self,
self.query = nn.Linear(self.width, self.all_head_dim)
self.key = nn.Linear(self.width, self.all_head_dim)
self.value = nn.Linear(self.width, self.all_head_dim)
- self.dropout = nn.Dropout(dropout_rate)
self.out = nn.Linear(self.width, self.width)
self.reset_parameters()
@@ -120,7 +118,8 @@ def transpose_for_scores(self, x: spec.Tensor) -> spec.Tensor:
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
- def forward(self, x: spec.Tensor) -> spec.Tensor:
+ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
+
mixed_query_layer = self.query(x)
key_layer = self.transpose_for_scores(self.key(x))
@@ -131,7 +130,7 @@ def forward(self, x: spec.Tensor) -> spec.Tensor:
attention_scores = attention_scores / math.sqrt(self.head_dim)
attention_probs = F.softmax(attention_scores, dim=-1)
- attention_probs = self.dropout(attention_probs)
+ attention_probs = F.dropout(attention_probs, dropout_rate, self.training)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
@@ -149,8 +148,7 @@ def __init__(self,
mlp_dim: Optional[int] = None,
num_heads: int = 12,
use_glu: bool = False,
- use_post_layer_norm: bool = False,
- dropout_rate: float = 0.0) -> None:
+ use_post_layer_norm: bool = False) -> None:
super().__init__()
self.width = width
@@ -161,35 +159,34 @@ def __init__(self,
self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6)
self.self_attention1 = SelfAttention(self.width, self.num_heads)
- self.dropout = nn.Dropout(dropout_rate)
self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6)
self.mlp3 = MlpBlock(
width=self.width,
mlp_dim=self.mlp_dim,
- use_glu=self.use_glu,
- dropout_rate=dropout_rate)
+ use_glu=self.use_glu)
+
+ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
- def forward(self, x: spec.Tensor) -> spec.Tensor:
if not self.use_post_layer_norm:
y = self.layer_norm0(x)
- y = self.self_attention1(y)
- y = self.dropout(y)
+ y = self.self_attention1(y, dropout_rate)
+ y = F.dropout(y, dropout_rate, training=self.training)
x = x + y
y = self.layer_norm2(x)
- y = self.mlp3(y)
- y = self.dropout(y)
+ y = self.mlp3(y, dropout_rate)
+ y = F.dropout(y, dropout_rate, training=self.training)
x = x + y
else:
y = x
- y = self.self_attention1(y)
- y = self.dropout(y)
+ y = self.self_attention1(y, dropout_rate)
+ y = F.dropout(y, dropout_rate, training=self.training)
x = x + y
x = self.layer_norm0(x)
y = x
- y = self.mlp3(y)
- y = self.dropout(y)
+ y = self.mlp3(y, dropout_rate)
+ y = F.dropout(y, dropout_rate, training=self.training)
x = x + y
x = self.layer_norm2(x)
return x
@@ -204,8 +201,7 @@ def __init__(self,
mlp_dim: Optional[int] = None,
num_heads: int = 12,
use_glu: bool = False,
- use_post_layer_norm: bool = False,
- dropout_rate: float = 0.0) -> None:
+ use_post_layer_norm: bool = False) -> None:
super().__init__()
self.depth = depth
@@ -220,8 +216,7 @@ def __init__(self,
self.mlp_dim,
self.num_heads,
self.use_glu,
- self.use_post_layer_norm,
- dropout_rate) for _ in range(depth)
+ self.use_post_layer_norm) for _ in range(depth)
])
if not self.use_post_layer_norm:
@@ -229,10 +224,10 @@ def __init__(self,
else:
self.encoder_norm = None
- def forward(self, x: spec.Tensor) -> spec.Tensor:
+ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
# Input Encoder.
for block in self.net:
- x = block(x)
+ x = block(x, dropout_rate)
if not self.use_post_layer_norm:
return self.encoder_norm(x)
else:
@@ -259,13 +254,13 @@ def __init__(self,
self.layer_norm = nn.LayerNorm(self.width, eps=1e-6)
self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim)
- def forward(self, x: spec.Tensor) -> spec.Tensor:
+ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
n, _, _ = x.shape
probe = torch.tile(self.probe, [n, 1, 1])
- x = self.mha(probe, x)[0]
+ x = self.mha(probe, x, dropout_rate=dropout_rate)[0]
y = self.layer_norm(x)
- x = x + self.mlp(y)
+ x = x + self.mlp(y, dropout_rate)
return x[:, 0]
@@ -285,15 +280,12 @@ def __init__(
mlp_dim: Optional[int] = None, # Defaults to 4x input dim.
num_heads: int = 12,
rep_size: Union[int, bool] = True,
- dropout_rate: Optional[float] = 0.0,
head_zeroinit: bool = True,
use_glu: bool = False,
use_post_layer_norm: bool = False,
use_map: bool = False,
dtype: Any = torch.float32) -> None:
super().__init__()
- if dropout_rate is None:
- dropout_rate = 0.0
self.num_classes = num_classes
self.patch_size = patch_size
@@ -318,7 +310,6 @@ def __init__(
self.patch_size,
stride=self.patch_size,
padding='valid')
- self.dropout = nn.Dropout(p=dropout_rate)
self.encoder = Encoder(
depth=self.depth,
@@ -326,8 +317,7 @@ def __init__(
mlp_dim=self.mlp_dim,
num_heads=self.num_heads,
use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm,
- dropout_rate=dropout_rate)
+ use_post_layer_norm=self.use_post_layer_norm)
if self.num_classes:
self.head = nn.Linear(self.width, self.num_classes)
@@ -355,7 +345,11 @@ def reset_parameters(self) -> None:
def get_posemb(self, x: spec.Tensor) -> spec.Tensor:
return posemb_sincos_2d(x).type(self.dtype)
- def forward(self, x: spec.Tensor) -> spec.Tensor:
+ def forward(
+ self,
+ x: spec.Tensor,
+ dropout_rate: float = DROPOUT_RATE) -> spec.Tensor:
+
# Patch extraction.
x = self.conv_patch_extract(x)
@@ -367,11 +361,11 @@ def forward(self, x: spec.Tensor) -> spec.Tensor:
x = torch.transpose(torch.reshape(x, (n, c, h * w)), 1, 2)
x = x + pes
- x = self.dropout(x)
- x = self.encoder(x)
+ x = F.dropout(x, dropout_rate, training=self.training)
+ x = self.encoder(x, dropout_rate)
if self.use_map:
- x = self.map(x)
+ x = self.map(x, dropout_rate)
else:
x = torch.mean(x, dim=1)
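
The ViT changes swap stored nn.Dropout modules for functional F.dropout calls so the rate can vary per forward pass. A minimal side-by-side sketch of the two styles, with hypothetical block names:

    import torch
    import torch.nn.functional as F
    from torch import nn

    class BlockWithModuleDropout(nn.Module):
        """Rate fixed at construction time."""
        def __init__(self, dim, dropout_rate=0.0):
            super().__init__()
            self.lin = nn.Linear(dim, dim)
            self.dropout = nn.Dropout(dropout_rate)

        def forward(self, x):
            return self.dropout(self.lin(x))

    class BlockWithCallTimeDropout(nn.Module):
        """Rate chosen per forward call."""
        def __init__(self, dim):
            super().__init__()
            self.lin = nn.Linear(dim, dim)

        def forward(self, x, dropout_rate=0.0):
            return F.dropout(self.lin(x), dropout_rate, training=self.training)

    x = torch.randn(2, 8)
    y = BlockWithCallTimeDropout(8)(x, dropout_rate=0.1)
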
diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py
deleted file mode 100644
index 60e09edb5..000000000
--- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models_dropout.py
+++ /dev/null
@@ -1,378 +0,0 @@
-"""PyTorch implementation of refactored and simplified ViT.
-
-Adapted from:
-https://github.com/huggingface/transformers/tree/main/src/transformers/models/vit
-and https://github.com/lucidrains/vit-pytorch.
-"""
-
-import math
-from typing import Any, Optional, Tuple, Union
-
-import torch
-from torch import nn
-import torch.nn.functional as F
-
-from algoperf import init_utils
-from algoperf import spec
-from algoperf.workloads.wmt.wmt_pytorch.models_dropout import MultiheadAttention
-
-DROPOUT_RATE = 0.0
-
-
-def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor:
- """Follows the MoCo v3 logic."""
- _, width, h, w = patches.shape
- device = patches.device
- y, x = torch.meshgrid(torch.arange(h, device=device),
- torch.arange(w, device=device), indexing='ij')
-
- if width % 4 != 0:
- raise ValueError('Width must be mult of 4 for sincos posemb.')
- omega = torch.arange(width // 4, device=device) / (width // 4 - 1)
- omega = 1. / (temperature**omega)
- y = y.flatten()[:, None] * omega[None, :]
- x = x.flatten()[:, None] * omega[None, :]
- pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
- return pe[None, :, :]
-
-
-class MlpBlock(nn.Module):
- """Transformer MLP / feed-forward block."""
-
- def __init__(
- self,
- width: int,
- mlp_dim: Optional[int] = None, # Defaults to 4x input dim.
- use_glu: bool = False) -> None:
- super().__init__()
-
- self.width = width
- self.mlp_dim = mlp_dim or 4 * width
- self.use_glu = use_glu
-
- self.linear1 = nn.Linear(self.width, self.mlp_dim)
- self.act_fnc = nn.GELU(approximate='tanh')
-
- if self.use_glu:
- self.glu_linear = nn.Linear(self.mlp_dim, self.mlp_dim)
- else:
- self.glu_linear = None
-
- self.linear2 = nn.Linear(self.mlp_dim, self.width)
-
- self.reset_parameters()
-
- def reset_parameters(self) -> None:
- for module in self.modules():
- if isinstance(module, nn.Linear):
- nn.init.xavier_uniform_(module.weight.data)
- if module.bias is not None:
- module.bias.data.normal_(std=1e-6)
-
- def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
-
- x = self.linear1(x)
- x = self.act_fnc(x)
-
- if self.use_glu:
- y = self.glu_linear(x)
- x = x * y
-
- x = F.dropout(x, dropout_rate, training=self.training)
- x = self.linear2(x)
- return x
-
-
-class SelfAttention(nn.Module):
- """Self-attention special case of multi-head dot-product attention."""
-
- def __init__(self,
- width: int,
- num_heads: int = 8) -> None:
- super().__init__()
-
- self.width = width
- self.num_heads = num_heads
-
- assert width % num_heads == 0, (
- 'Memory dimension must be divisible by number of heads.')
-
- self.head_dim = int(width / num_heads)
- self.all_head_dim = self.num_heads * self.head_dim
-
- self.query = nn.Linear(self.width, self.all_head_dim)
- self.key = nn.Linear(self.width, self.all_head_dim)
- self.value = nn.Linear(self.width, self.all_head_dim)
- self.out = nn.Linear(self.width, self.width)
- self.reset_parameters()
-
- def reset_parameters(self) -> None:
- for module in self.modules():
- if isinstance(module, nn.Linear):
- nn.init.xavier_uniform_(module.weight.data)
- if module.bias is not None:
- nn.init.constant_(module.bias.data, 0.)
-
- def transpose_for_scores(self, x: spec.Tensor) -> spec.Tensor:
- new_x_shape = x.size()[:-1] + (self.num_heads, self.head_dim)
- x = x.view(new_x_shape)
- return x.permute(0, 2, 1, 3)
-
- def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
-
- mixed_query_layer = self.query(x)
-
- key_layer = self.transpose_for_scores(self.key(x))
- value_layer = self.transpose_for_scores(self.value(x))
- query_layer = self.transpose_for_scores(mixed_query_layer)
-
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
- attention_scores = attention_scores / math.sqrt(self.head_dim)
-
- attention_probs = F.softmax(attention_scores, dim=-1)
- attention_probs = F.dropout(attention_probs, dropout_rate, self.training)
-
- context_layer = torch.matmul(attention_probs, value_layer)
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_dim,)
- context_layer = context_layer.view(new_context_layer_shape)
- out = self.out(context_layer)
- return out
-
-
-class Encoder1DBlock(nn.Module):
- """Single transformer encoder block (MHSA + MLP)."""
-
- def __init__(self,
- width: int,
- mlp_dim: Optional[int] = None,
- num_heads: int = 12,
- use_glu: bool = False,
- use_post_layer_norm: bool = False) -> None:
- super().__init__()
-
- self.width = width
- self.mlp_dim = mlp_dim
- self.num_heads = num_heads
- self.use_glu = use_glu
- self.use_post_layer_norm = use_post_layer_norm
-
- self.layer_norm0 = nn.LayerNorm(self.width, eps=1e-6)
- self.self_attention1 = SelfAttention(self.width, self.num_heads)
- self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6)
- self.mlp3 = MlpBlock(
- width=self.width,
- mlp_dim=self.mlp_dim,
- use_glu=self.use_glu)
-
- def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
-
- if not self.use_post_layer_norm:
- y = self.layer_norm0(x)
- y = self.self_attention1(y, dropout_rate)
- y = F.dropout(y, dropout_rate, training=self.training)
- x = x + y
-
- y = self.layer_norm2(x)
- y = self.mlp3(y, dropout_rate)
- y = F.dropout(y, dropout_rate, training=self.training)
- x = x + y
- else:
- y = x
- y = self.self_attention1(y, dropout_rate)
- y = F.dropout(y, dropout_rate, training=self.training)
- x = x + y
- x = self.layer_norm0(x)
-
- y = x
- y = self.mlp3(y, dropout_rate)
- y = F.dropout(y, dropout_rate, training=self.training)
- x = x + y
- x = self.layer_norm2(x)
- return x
-
-
-class Encoder(nn.Module):
- """Transformer Model Encoder for sequence to sequence translation."""
-
- def __init__(self,
- depth: int,
- width: int,
- mlp_dim: Optional[int] = None,
- num_heads: int = 12,
- use_glu: bool = False,
- use_post_layer_norm: bool = False) -> None:
- super().__init__()
-
- self.depth = depth
- self.width = width
- self.mlp_dim = mlp_dim
- self.num_heads = num_heads
- self.use_glu = use_glu
- self.use_post_layer_norm = use_post_layer_norm
-
- self.net = nn.ModuleList([
- Encoder1DBlock(self.width,
- self.mlp_dim,
- self.num_heads,
- self.use_glu,
- self.use_post_layer_norm) for _ in range(depth)
- ])
-
- if not self.use_post_layer_norm:
- self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6)
- else:
- self.encoder_norm = None
-
- def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
- # Input Encoder.
- for block in self.net:
- x = block(x, dropout_rate)
- if not self.use_post_layer_norm:
- return self.encoder_norm(x)
- else:
- return x
-
-
-class MAPHead(nn.Module):
- """Multihead Attention Pooling."""
-
- def __init__(self,
- width: int,
- mlp_dim: Optional[int] = None,
- num_heads: int = 12):
- super().__init__()
- self.width = width
- self.mlp_dim = mlp_dim
- self.num_heads = num_heads
-
- self.probe = nn.Parameter(torch.zeros((1, 1, self.width)))
- nn.init.xavier_uniform_(self.probe.data)
-
- self.mha = MultiheadAttention(
- self.width, num_heads=self.num_heads, self_attn=False, bias=True)
- self.layer_norm = nn.LayerNorm(self.width, eps=1e-6)
- self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim)
-
- def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
- n, _, _ = x.shape
- probe = torch.tile(self.probe, [n, 1, 1])
-
- x = self.mha(probe, x, dropout_rate=dropout_rate)[0]
- y = self.layer_norm(x)
- x = x + self.mlp(y, dropout_rate)
- return x[:, 0]
-
-
-class ViT(nn.Module):
- """ViT model."""
-
- image_height: int = 224
- image_width: int = 224
- channels: int = 3
-
- def __init__(
- self,
- num_classes: int = 1000,
- patch_size: Tuple[int, int] = (16, 16),
- width: int = 768,
- depth: int = 12,
- mlp_dim: Optional[int] = None, # Defaults to 4x input dim.
- num_heads: int = 12,
- rep_size: Union[int, bool] = True,
- head_zeroinit: bool = True,
- use_glu: bool = False,
- use_post_layer_norm: bool = False,
- use_map: bool = False,
- dtype: Any = torch.float32) -> None:
- super().__init__()
-
- self.num_classes = num_classes
- self.patch_size = patch_size
- self.width = width
- self.depth = depth
- self.mlp_dim = mlp_dim
- self.num_heads = num_heads
- self.rep_size = rep_size
- self.head_zeroinit = head_zeroinit
- self.use_glu = use_glu
- self.use_post_layer_norm = use_post_layer_norm
- self.use_map = use_map
- self.dtype = dtype
-
- if self.rep_size:
- rep_size = self.width if self.rep_size is True else self.rep_size
- self.pre_logits = nn.Linear(self.width, rep_size)
-
- self.conv_patch_extract = nn.Conv2d(
- self.channels,
- self.width,
- self.patch_size,
- stride=self.patch_size,
- padding='valid')
-
- self.encoder = Encoder(
- depth=self.depth,
- width=self.width,
- mlp_dim=self.mlp_dim,
- num_heads=self.num_heads,
- use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm)
-
- if self.num_classes:
- self.head = nn.Linear(self.width, self.num_classes)
-
- if self.use_map:
- self.map = MAPHead(self.width, self.mlp_dim, self.num_heads)
- else:
- self.map = None
-
- self.reset_parameters()
-
- def reset_parameters(self) -> None:
- init_utils.pytorch_default_init(self.conv_patch_extract)
-
- if self.rep_size:
- init_utils.pytorch_default_init(self.pre_logits)
-
- if self.num_classes:
- if self.head_zeroinit:
- nn.init.constant_(self.head.weight.data, 0.)
- nn.init.constant_(self.head.bias.data, 0.)
- else:
- init_utils.pytorch_default_init(self.head)
-
- def get_posemb(self, x: spec.Tensor) -> spec.Tensor:
- return posemb_sincos_2d(x).type(self.dtype)
-
- def forward(
- self,
- x: spec.Tensor,
- dropout_rate: float = DROPOUT_RATE) -> spec.Tensor:
-
- # Patch extraction.
- x = self.conv_patch_extract(x)
-
- # Add posemb before adding extra token.
- n, c, h, w = x.shape
- pes = self.get_posemb(x)
-
- # Reshape to match Jax's ViT implementation.
- x = torch.transpose(torch.reshape(x, (n, c, h * w)), 1, 2)
- x = x + pes
-
- x = F.dropout(x, dropout_rate, training=self.training)
- x = self.encoder(x, dropout_rate)
-
- if self.use_map:
- x = self.map(x, dropout_rate)
- else:
- x = torch.mean(x, dim=1)
-
- if self.rep_size:
- x = torch.tanh(self.pre_logits(x))
-
- if self.num_classes:
- x = self.head(x)
-
- return x
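
Note on the pattern removed above: like the rewritten models.py files in this patch, the forked ViT applies dropout functionally, threading a dropout_rate argument through every forward() instead of holding nn.Dropout submodules. A minimal sketch of that pattern, assuming nothing beyond torch itself (the TinyBlock module below is hypothetical and purely illustrative):

    import torch
    import torch.nn.functional as F
    from torch import nn

    class TinyBlock(nn.Module):
        """Hypothetical block showing call-time dropout."""

        def __init__(self, dim: int) -> None:
            super().__init__()
            self.linear = nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor, dropout_rate: float = 0.0) -> torch.Tensor:
            x = self.linear(x)
            # The rate is read at call time, so it can change between steps
            # without rebuilding the module.
            return F.dropout(x, dropout_rate, training=self.training)

    block = TinyBlock(8).train()
    x = torch.randn(2, 8)
    y_weak = block(x, dropout_rate=0.1)    # same parameters, different rates
    y_strong = block(x, dropout_rate=0.5)
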
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py
index db1e24521..a6a60bf95 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py
@@ -5,7 +5,7 @@
from dataclasses import dataclass
from functools import partial
import math
-from typing import Tuple
+from typing import Optional, Tuple
import torch
from torch import nn
@@ -17,6 +17,8 @@
from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \
SpecAug
+DROPOUT_RATE = 0.1
+
@dataclass
class ConformerConfig:
@@ -26,10 +28,7 @@ class ConformerConfig:
num_attention_heads: int = 8
num_encoder_layers: int = 4
attention_dropout_rate: float = 0.0
- attention_residual_dropout_rate: float = 0.1
- conv_residual_dropout_rate: float = 0.0
feed_forward_dropout_rate: float = 0.0
- feed_forward_residual_dropout_rate: float = 0.1
convolution_kernel_size: int = 5
feed_forward_expansion_factor: int = 4
freq_mask_count: int = 2
@@ -39,7 +38,6 @@ class ConformerConfig:
time_mask_max_ratio: float = 0.05
time_masks_per_frame: float = 0.0
use_dynamic_time_mask_max_frames: bool = True
- input_dropout_rate: float = 0.1
batch_norm_momentum: float = 1 - 0.999
batch_norm_epsilon: float = 0.001
use_specaug: bool = True
@@ -77,11 +75,9 @@ class Subsample(nn.Module):
def __init__(self,
encoder_dim: int = 0,
- input_dropout_rate: float = 0.0,
num_bins: int = 80):
super().__init__()
self.encoder_dim = encoder_dim
- self.input_dropout_rate = input_dropout_rate
self.conv1 = Conv2dSubsampling(
input_channels=1, output_channels=encoder_dim)
@@ -93,9 +89,9 @@ def __init__(self,
out_features=self.encoder_dim,
bias=True)
self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim)
- self.dropout = nn.Dropout(p=self.input_dropout_rate, inplace=True)
- def forward(self, inputs, input_paddings):
+ def forward(self, inputs, input_paddings, dropout_rate):
+
output_paddings = input_paddings
outputs = inputs[:, None, :, :]
@@ -109,7 +105,7 @@ def forward(self, inputs, input_paddings):
outputs = self.linear(outputs)
outputs = outputs + self.pos_encode(seq_length=outputs.shape[1])
- outputs = self.dropout(outputs)
+ outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True)
return outputs, output_paddings
@@ -201,15 +197,8 @@ def __init__(self, config: ConformerConfig):
out_features=config.encoder_dim,
bias=True)
- if config.feed_forward_residual_dropout_rate is None:
- feed_forward_residual_dropout_rate = 0.1
- else:
- feed_forward_residual_dropout_rate = (
- config.feed_forward_residual_dropout_rate)
- self.dropout2 = nn.Dropout(
- p=feed_forward_residual_dropout_rate, inplace=True)
+ def forward(self, inputs, padding_mask, dropout_rate):
- def forward(self, inputs, padding_mask):
inputs = self.ln(inputs)
inputs = self.linear1(inputs)
if self.config.activation_function_name == 'swish':
@@ -226,7 +215,7 @@ def forward(self, inputs, padding_mask):
inputs = inputs * padding_mask
inputs = self.linear2(inputs)
inputs = inputs * padding_mask
- inputs = self.dropout2(inputs)
+ inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True)
return inputs
@@ -280,7 +269,7 @@ def __init__(self, config: ConformerConfig):
super().__init__()
self.embed_dim = config.encoder_dim
self.num_heads = config.num_attention_heads
- self.dropout = config.attention_dropout_rate
+ self.attention_dropout_rate = config.attention_dropout_rate
self.in_proj = nn.Linear(config.encoder_dim, 3 * config.encoder_dim)
self.out_proj = nn.Linear(config.encoder_dim, config.encoder_dim)
self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads)
@@ -297,7 +286,7 @@ def forward(self, inputs, key_padding_mask=None):
key=k,
value=v,
attn_mask=~key_padding_mask[:, None, None],
- dropout_p=self.dropout,
+ dropout_p=self.attention_dropout_rate,
).transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
out = out * self.attention_temperature
out = self.out_proj(out)
@@ -313,19 +302,14 @@ def __init__(self, config: ConformerConfig):
self.ln = LayerNorm(dim=config.encoder_dim)
self.self_attention = MHSAwithQS(config)
- if config.attention_residual_dropout_rate is None:
- attention_residual_dropout_rate = 0.1
- else:
- attention_residual_dropout_rate = config.attention_residual_dropout_rate
- self.dropout = nn.Dropout(p=attention_residual_dropout_rate, inplace=True)
- def forward(self, outputs, paddings):
+ def forward(self, outputs, paddings, dropout_rate):
outputs = self.ln(outputs)
outputs = self.self_attention(
outputs,
key_padding_mask=paddings == 1,
)
- outputs = self.dropout(outputs)
+ outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True)
return outputs
@@ -405,13 +389,8 @@ def __init__(self, config):
groups=config.encoder_dim)
self.bn = BatchNorm(config)
self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim)
- if config.conv_residual_dropout_rate is None:
- conv_residual_dropout_rate = 0.0
- else:
- conv_residual_dropout_rate = config.conv_residual_dropout_rate
- self.dropout = nn.Dropout(p=conv_residual_dropout_rate, inplace=True)
- def forward(self, inputs, input_paddings):
+ def forward(self, inputs, input_paddings, dropout_rate):
inputs = self.ln(inputs)
inputs = F.glu(torch.cat([self.lin1(inputs), self.lin2(inputs)], dim=2))
@@ -433,7 +412,7 @@ def forward(self, inputs, input_paddings):
inputs = activation_fn(inputs)
inputs = self.lin3(inputs)
- inputs = self.dropout(inputs)
+ inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True)
return inputs
@@ -450,12 +429,12 @@ def __init__(self, config: ConformerConfig):
if config.use_post_layer_norm:
self.ln = LayerNorm(dim=config.encoder_dim)
- def forward(self, inputs, input_paddings):
+ def forward(self, inputs, input_paddings, dropout_rate):
padding_mask = 1 - input_paddings[:, :, None]
- inputs = inputs + 0.5 * self.ff1(inputs, padding_mask)
- inputs = inputs + self.mhsa(inputs, input_paddings)
- inputs = inputs + self.conv(inputs, input_paddings)
- inputs = inputs + 0.5 * self.ff2(inputs, padding_mask)
+ inputs = inputs + 0.5 * self.ff1(inputs, padding_mask, dropout_rate)
+ inputs = inputs + self.mhsa(inputs, input_paddings, dropout_rate)
+ inputs = inputs + self.conv(inputs, input_paddings, dropout_rate)
+ inputs = inputs + 0.5 * self.ff2(inputs, padding_mask, dropout_rate)
if self.ln:
inputs = self.ln(inputs)
return inputs
@@ -480,13 +459,8 @@ def __init__(self, config: ConformerConfig):
time_masks_per_frame=config.time_masks_per_frame,
use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames
)
- if config.input_dropout_rate is None:
- input_dropout_rate = 0.1
- else:
- input_dropout_rate = config.input_dropout_rate
self.subsample = Subsample(
encoder_dim=config.encoder_dim,
- input_dropout_rate=input_dropout_rate,
num_bins=preprocessing_config.num_bins)
self.conformers = nn.ModuleList(
[ConformerBlock(config) for _ in range(config.num_encoder_layers)])
@@ -494,15 +468,15 @@ def __init__(self, config: ConformerConfig):
self.ln = LayerNorm(config.encoder_dim)
self.lin = nn.Linear(config.encoder_dim, config.vocab_size)
- def forward(self, inputs, input_paddings):
+ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE):
outputs = inputs
output_paddings = input_paddings
outputs, output_paddings = self.preprocessor(outputs, output_paddings)
if self.training and self.config.use_specaug:
outputs, output_paddings = self.specaug(outputs, output_paddings)
- outputs, output_paddings = self.subsample(outputs, output_paddings)
+ outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate)
for conformer in self.conformers:
- outputs = conformer(outputs, output_paddings)
+ outputs = conformer(outputs, output_paddings, dropout_rate)
outputs = self.ln(outputs)
outputs = self.lin(outputs)
return outputs, output_paddings
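
With the hunks above, the per-submodule residual and input dropout fields are gone from ConformerConfig; a single dropout_rate flows through forward() and defaults to DROPOUT_RATE = 0.1, so existing call sites keep their previous behaviour. A hedged usage sketch of the new signature (the batch shapes below are illustrative assumptions, not taken from the workload):

    import torch
    from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import (
        ConformerConfig, ConformerEncoderDecoder)

    model = ConformerEncoderDecoder(ConformerConfig()).train()
    inputs = torch.randn(2, 16_000)      # assumed raw-audio batch (batch, samples)
    paddings = torch.zeros(2, 16_000)    # 0 = valid sample, 1 = padding

    logits, out_paddings = model(inputs, paddings)                    # default rate 0.1
    logits, out_paddings = model(inputs, paddings, dropout_rate=0.0)  # per-call override
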
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py
deleted file mode 100644
index a6a60bf95..000000000
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models_dropout.py
+++ /dev/null
@@ -1,482 +0,0 @@
-"""This is a pytorch implementation mirroring:
-https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py.
-"""
-
-from dataclasses import dataclass
-from functools import partial
-import math
-from typing import Optional, Tuple
-
-import torch
-from torch import nn
-from torch.nn import init
-import torch.nn.functional as F
-
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \
- preprocessor
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \
- SpecAug
-
-DROPOUT_RATE = 0.1
-
-
-@dataclass
-class ConformerConfig:
- """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
- vocab_size: int = 1024
- encoder_dim: int = 512
- num_attention_heads: int = 8
- num_encoder_layers: int = 4
- attention_dropout_rate: float = 0.0
- feed_forward_dropout_rate: float = 0.0
- convolution_kernel_size: int = 5
- feed_forward_expansion_factor: int = 4
- freq_mask_count: int = 2
- freq_mask_max_bins: int = 27
- time_mask_count: int = 10
- time_mask_max_frames: int = 40
- time_mask_max_ratio: float = 0.05
- time_masks_per_frame: float = 0.0
- use_dynamic_time_mask_max_frames: bool = True
- batch_norm_momentum: float = 1 - 0.999
- batch_norm_epsilon: float = 0.001
- use_specaug: bool = True
- attention_temperature: float = 1.0
- activation_function_name: str = 'swish'
- use_post_layer_norm: bool = True
-
-
-def initialize(m):
- if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d):
- init.xavier_uniform_(m.weight)
- if m.bias is not None:
- init.constant_(m.bias, 0)
- elif isinstance(m, nn.MultiheadAttention):
- init.xavier_uniform_(m.in_proj_weight)
- for i in m.children():
- initialize(i)
-
-
-class LayerNorm(nn.Module):
-
- def __init__(self, dim, epsilon=1e-6):
- super().__init__()
- self.dim = dim
-
- self.scale = nn.Parameter(torch.zeros(self.dim))
- self.bias = nn.Parameter(torch.zeros(self.dim))
- self.epsilon = epsilon
-
- def forward(self, x):
- return F.layer_norm(x, (self.dim,), 1 + self.scale, self.bias, self.epsilon)
-
-
-class Subsample(nn.Module):
-
- def __init__(self,
- encoder_dim: int = 0,
- num_bins: int = 80):
- super().__init__()
- self.encoder_dim = encoder_dim
-
- self.conv1 = Conv2dSubsampling(
- input_channels=1, output_channels=encoder_dim)
- self.conv2 = Conv2dSubsampling(
- input_channels=encoder_dim, output_channels=encoder_dim)
-
- self.linear = nn.Linear(
- in_features=self.encoder_dim * num_bins // 4,
- out_features=self.encoder_dim,
- bias=True)
- self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim)
-
- def forward(self, inputs, input_paddings, dropout_rate):
-
- output_paddings = input_paddings
- outputs = inputs[:, None, :, :]
-
- outputs, output_paddings = self.conv1(outputs, output_paddings)
- outputs, output_paddings = self.conv2(outputs, output_paddings)
-
- batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape
- outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size,
- subsampled_lengths,
- subsampled_dims * channels)
-
- outputs = self.linear(outputs)
- outputs = outputs + self.pos_encode(seq_length=outputs.shape[1])
- outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True)
-
- return outputs, output_paddings
-
-
-class Conv2dSubsampling(nn.Module):
-
- def __init__(self,
- input_channels: int,
- output_channels: int,
- filter_stride: Tuple[int] = (2, 2),
- padding: str = 'SAME'):
- super().__init__()
-
- self.input_channels = input_channels
- self.output_channels = output_channels
- self.filter_stride = filter_stride
- self.padding = padding
-
- self.filter_shape = (output_channels, input_channels, 3, 3)
-
- self.kernel = nn.Parameter(
- torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape)))
- self.bias = nn.Parameter(torch.zeros(output_channels))
- self.register_buffer('paddings_kernel', torch.ones([1, 1, 1]))
-
- def get_same_padding(self, input_shape):
- in_height, in_width = input_shape[2:]
- stride_height, stride_width = self.filter_stride
- filter_height, filter_width = 3, 3
- if in_height % stride_height == 0:
- pad_along_height = max(filter_height - stride_height, 0)
- else:
- pad_along_height = max(filter_height - (in_height % stride_height), 0)
- if in_width % stride_width == 0:
- pad_along_width = max(filter_width - stride_width, 0)
- else:
- pad_along_width = max(filter_width - (in_width % stride_width), 0)
- pad_top = pad_along_height // 2
- pad_bottom = pad_along_height - pad_top
- pad_left = pad_along_width // 2
- pad_right = pad_along_width - pad_left
- return (pad_left, pad_right, pad_top, pad_bottom)
-
- def forward(self, inputs, paddings):
- groups = inputs.shape[1] // self.input_channels
-
- if self.padding == 'SAME':
- in_ = F.pad(inputs, self.get_same_padding(inputs.shape))
- else:
- in_ = inputs
- outputs = F.conv2d(
- input=in_,
- weight=self.kernel,
- bias=self.bias,
- stride=self.filter_stride,
- dilation=(1, 1),
- groups=groups)
-
- outputs = F.relu(outputs)
-
- input_length = paddings.shape[1]
- stride = self.filter_stride[0]
- pad_len = (input_length + stride - 1) // stride * stride - input_length
- padded_paddings = F.pad(
- paddings[:, None, :], (0, pad_len), mode='constant', value=0)
- out_padding = F.conv1d(
- input=padded_paddings,
- weight=self.paddings_kernel,
- stride=self.filter_stride[:1])
- out_padding = out_padding.squeeze(dim=1)
- outputs = outputs * (1 - out_padding[:, None, :, None])
- return outputs, out_padding
-
-
-class FeedForwardModule(nn.Module):
-
- def __init__(self, config: ConformerConfig):
- super().__init__()
- self.config = config
-
- self.ln = LayerNorm(dim=config.encoder_dim)
- self.linear1 = nn.Linear(
- in_features=config.encoder_dim,
- out_features=config.encoder_dim * config.feed_forward_expansion_factor,
- bias=True)
- self.dropout1 = nn.Dropout(p=config.feed_forward_dropout_rate, inplace=True)
- self.linear2 = nn.Linear(
- in_features=config.encoder_dim * config.feed_forward_expansion_factor,
- out_features=config.encoder_dim,
- bias=True)
-
- def forward(self, inputs, padding_mask, dropout_rate):
-
- inputs = self.ln(inputs)
- inputs = self.linear1(inputs)
- if self.config.activation_function_name == 'swish':
- activation_fn = F.silu
- elif self.config.activation_function_name == 'gelu':
- # Use the tanh approximation of GELU, which is the default in JAX
- activation_fn = partial(F.gelu, approximate='tanh')
- else:
- raise ValueError('Only "swish" and "gelu" are supported '
- 'config.activation_function_name values, received '
- f'{self.config.activation_function_name}')
- inputs = activation_fn(inputs)
- inputs = self.dropout1(inputs)
- inputs = inputs * padding_mask
- inputs = self.linear2(inputs)
- inputs = inputs * padding_mask
- inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True)
-
- return inputs
-
-
-class AddPositionalEmbedding(nn.Module):
-
- def __init__(self,
- min_timescale: int = 1,
- max_timescale: int = 10_000,
- embedding_dim: int = 512):
- super().__init__()
- self.min_timescale = min_timescale
- self.max_timescale = max_timescale
- self.embedding_dim = embedding_dim
- num_timescales = self.embedding_dim // 2
- log_timescale_increment = math.log(
- float(self.max_timescale) / float(self.min_timescale)) / (
- num_timescales - 1)
- inv_timescales = self.min_timescale * \
- torch.exp(torch.arange(num_timescales, dtype=torch.float32)
- * -log_timescale_increment)
- self.register_buffer('inv_timescales', inv_timescales[None, None, :])
-
- def forward(self, seq_length):
- position = torch.arange(
- end=seq_length, dtype=torch.float32, device=self.inv_timescales.device)
- scaled_time = position[None, :, None] * self.inv_timescales
- signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
- if self.embedding_dim % 2:
- signal = torch.cat(
- [signal, torch.zeros(signal.shape[0], signal.shape[1], 1)], dim=2)
- return signal
-
-
-class QueryScaler(nn.Module):
-
- def __init__(self, dim):
- super().__init__()
- self.dim = dim
- self.scale = nn.Parameter(torch.zeros(self.dim))
-
- def forward(self, inputs):
- r_softplus_0 = 1.442695041
- scale = r_softplus_0 * F.softplus(self.scale)
- return inputs * scale
-
-
-class MHSAwithQS(nn.Module):
-
- def __init__(self, config: ConformerConfig):
- super().__init__()
- self.embed_dim = config.encoder_dim
- self.num_heads = config.num_attention_heads
- self.attention_dropout_rate = config.attention_dropout_rate
- self.in_proj = nn.Linear(config.encoder_dim, 3 * config.encoder_dim)
- self.out_proj = nn.Linear(config.encoder_dim, config.encoder_dim)
- self.qs = QueryScaler(dim=config.encoder_dim // config.num_attention_heads)
- self.attention_temperature = config.attention_temperature
-
- def forward(self, inputs, key_padding_mask=None):
- batch_size, seq_len, embed_dim = inputs.shape
- q, k, v = self.in_proj(inputs).split(self.embed_dim, dim=2)
- q = self.qs(q.view(batch_size, seq_len, self.num_heads, -1)).transpose(1, 2)
- k = k.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
- v = v.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
- out = F.scaled_dot_product_attention(
- query=q,
- key=k,
- value=v,
- attn_mask=~key_padding_mask[:, None, None],
- dropout_p=self.attention_dropout_rate,
- ).transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
- out = out * self.attention_temperature
- out = self.out_proj(out)
- return out
-
-
-class MultiHeadedSelfAttention(nn.Module):
-
- def __init__(self, config: ConformerConfig):
- super().__init__()
-
- self.config = config
-
- self.ln = LayerNorm(dim=config.encoder_dim)
- self.self_attention = MHSAwithQS(config)
-
- def forward(self, outputs, paddings, dropout_rate):
- outputs = self.ln(outputs)
- outputs = self.self_attention(
- outputs,
- key_padding_mask=paddings == 1,
- )
- outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True)
- return outputs
-
-
-class BatchNorm(nn.Module):
-
- def __init__(self, config: ConformerConfig):
- super().__init__()
- running_mean = torch.zeros(config.encoder_dim)
- running_var = torch.ones(config.encoder_dim)
- self.register_buffer('running_mean', running_mean)
- self.register_buffer('running_var', running_var)
- self.scale = nn.Parameter(torch.zeros(config.encoder_dim))
- self.bias = nn.Parameter(torch.zeros(config.encoder_dim))
-
- self.register_buffer('dim', torch.FloatTensor([config.encoder_dim]))
- self.momentum = config.batch_norm_momentum
- self.epsilon = config.batch_norm_epsilon
-
- def forward(self, inputs, input_paddings):
- #inputs: NHD
- #padding: NH
- """
- Alternatively:
- inputs[input_paddings==0] = F.batch_norm(
- input = inputs[input_paddings==0],
- running_mean = self.running_mean,
- running_var = self.running_var,
- weight = 1+self.scale,
- bias = self.bias,
- training = self.training,
- momentum=1-self.momentum,
- eps=self.epsilon
- )
- inputs.masked_fill(input_paddings[...,None] != 0, 0)
- return inputs
- """
- mask = 1 - input_paddings[:, :, None]
- if self.training:
- count = mask.sum()
- masked_inp = inputs.masked_fill(mask == 0, 0)
- mean = (masked_inp).sum(dim=(0, 1)) / count
- var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count
-
- self.running_mean = (1 - self.momentum) * self.running_mean + (
- self.momentum) * mean.detach()
- self.running_var = (1 - self.momentum) * self.running_var + (
- self.momentum) * var.detach()
-
- else:
- mean = self.running_mean
- var = self.running_var
- v = (1 + self.scale) * torch.rsqrt(var + self.epsilon)
- bn = (inputs - mean) * v + self.bias
- output = bn.masked_fill(mask == 0, 0)
- return output
-
-
-class ConvolutionBlock(nn.Module):
-
- def __init__(self, config):
- super().__init__()
-
- self.config = config
- self.ln = LayerNorm(dim=config.encoder_dim)
- self.lin1 = nn.Linear(
- in_features=config.encoder_dim, out_features=config.encoder_dim)
- self.lin2 = nn.Linear(
- in_features=config.encoder_dim, out_features=config.encoder_dim)
-
- self.conv1 = nn.Conv1d(
- in_channels=config.encoder_dim,
- out_channels=config.encoder_dim,
- kernel_size=(config.convolution_kernel_size,),
- stride=(1,),
- padding='same',
- bias=False,
- groups=config.encoder_dim)
- self.bn = BatchNorm(config)
- self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim)
-
- def forward(self, inputs, input_paddings, dropout_rate):
- inputs = self.ln(inputs)
-
- inputs = F.glu(torch.cat([self.lin1(inputs), self.lin2(inputs)], dim=2))
- inputs = inputs * (1 - input_paddings[:, :, None])
-
- inputs = inputs.permute(0, 2, 1)
- inputs = self.conv1(inputs)
- inputs = inputs.permute(0, 2, 1)
-
- inputs = self.bn(inputs, input_paddings)
- if self.config.activation_function_name == 'swish':
- activation_fn = F.silu
- elif self.config.activation_function_name == 'gelu':
- activation_fn = F.gelu
- else:
- raise ValueError('Only "swish" and "gelu" are supported '
- 'config.activation_function_name values, received '
- f'{self.config.activation_function_name}')
- inputs = activation_fn(inputs)
- inputs = self.lin3(inputs)
-
- inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True)
- return inputs
-
-
-class ConformerBlock(nn.Module):
-
- def __init__(self, config: ConformerConfig):
- super().__init__()
-
- self.ff1 = FeedForwardModule(config)
- self.mhsa = MultiHeadedSelfAttention(config)
- self.conv = ConvolutionBlock(config)
- self.ff2 = FeedForwardModule(config)
- self.ln = None
- if config.use_post_layer_norm:
- self.ln = LayerNorm(dim=config.encoder_dim)
-
- def forward(self, inputs, input_paddings, dropout_rate):
- padding_mask = 1 - input_paddings[:, :, None]
- inputs = inputs + 0.5 * self.ff1(inputs, padding_mask, dropout_rate)
- inputs = inputs + self.mhsa(inputs, input_paddings, dropout_rate)
- inputs = inputs + self.conv(inputs, input_paddings, dropout_rate)
- inputs = inputs + 0.5 * self.ff2(inputs, padding_mask, dropout_rate)
- if self.ln:
- inputs = self.ln(inputs)
- return inputs
-
-
-class ConformerEncoderDecoder(nn.Module):
-
- def __init__(self, config: ConformerConfig):
- super().__init__()
- self.config = config
- preprocessing_config = preprocessor.PreprocessorConfig()
- self.preprocessor = preprocessor.MelFilterbankFrontend(
- preprocessing_config,
- per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
- per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)
- self.specaug = SpecAug(
- freq_mask_count=config.freq_mask_count,
- freq_mask_max_bins=config.freq_mask_max_bins,
- time_mask_count=config.time_mask_count,
- time_mask_max_frames=config.time_mask_max_frames,
- time_mask_max_ratio=config.time_mask_max_ratio,
- time_masks_per_frame=config.time_masks_per_frame,
- use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames
- )
- self.subsample = Subsample(
- encoder_dim=config.encoder_dim,
- num_bins=preprocessing_config.num_bins)
- self.conformers = nn.ModuleList(
- [ConformerBlock(config) for _ in range(config.num_encoder_layers)])
-
- self.ln = LayerNorm(config.encoder_dim)
- self.lin = nn.Linear(config.encoder_dim, config.vocab_size)
-
- def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE):
- outputs = inputs
- output_paddings = input_paddings
- outputs, output_paddings = self.preprocessor(outputs, output_paddings)
- if self.training and self.config.use_specaug:
- outputs, output_paddings = self.specaug(outputs, output_paddings)
- outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate)
- for conformer in self.conformers:
- outputs = conformer(outputs, output_paddings, dropout_rate)
- outputs = self.ln(outputs)
- outputs = self.lin(outputs)
- return outputs, output_paddings
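
One detail worth keeping in mind from the file deleted above: its padding-aware BatchNorm computes mean and variance only over non-padded frames. A self-contained sketch of that masked-statistics step (tensor names and shapes here are illustrative):

    import torch

    inputs = torch.randn(4, 10, 8)        # (batch, time, feature)
    input_paddings = torch.zeros(4, 10)   # 1 marks padded frames
    input_paddings[:, 7:] = 1.0

    mask = 1 - input_paddings[:, :, None]
    count = mask.sum()                                 # number of valid frames
    masked_inp = inputs.masked_fill(mask == 0, 0)
    mean = masked_inp.sum(dim=(0, 1)) / count
    var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count
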
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
index 84d317326..a8480a343 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
@@ -17,6 +17,7 @@
SpecAug
USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
+DROPOUT_RATE = 0.1
@dataclass
@@ -38,10 +39,6 @@ class DeepspeechConfig:
use_dynamic_time_mask_max_frames: bool = True
batch_norm_momentum: float = 1 - 0.999
batch_norm_epsilon: float = 0.001
- # If None, defaults to 0.1.
- input_dropout_rate: Optional[float] = 0.1
- # If None, defaults to 0.1.
- feed_forward_dropout_rate: Optional[float] = 0.1
enable_residual_connections: bool = True
enable_decoder_layer_norm: bool = True
bidirectional: bool = True
@@ -87,13 +84,8 @@ def __init__(self, config: DeepspeechConfig):
self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True)
- if config.input_dropout_rate is None:
- input_dropout_rate = 0.1
- else:
- input_dropout_rate = config.input_dropout_rate
- self.dropout = nn.Dropout(p=input_dropout_rate)
+ def forward(self, inputs, input_paddings, dropout_rate):
- def forward(self, inputs, input_paddings):
output_paddings = input_paddings
outputs = inputs[:, None, :, :]
@@ -106,7 +98,7 @@ def forward(self, inputs, input_paddings):
subsampled_dims * channels)
outputs = self.lin(outputs)
- outputs = self.dropout(outputs)
+ outputs = F.dropout(outputs, dropout_rate, training=self.training)
return outputs, output_paddings
@@ -205,13 +197,9 @@ def __init__(self, config: DeepspeechConfig):
batch_norm_momentum=config.batch_norm_momentum,
batch_norm_epsilon=config.batch_norm_epsilon)
self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True)
- if config.feed_forward_dropout_rate is None:
- feed_forward_dropout_rate = 0.1
- else:
- feed_forward_dropout_rate = config.feed_forward_dropout_rate
- self.dropout = nn.Dropout(p=feed_forward_dropout_rate)
- def forward(self, inputs, input_paddings):
+ def forward(self, inputs, input_paddings, dropout_rate):
+
padding_mask = (1 - input_paddings)[:, :, None]
if self.config.layernorm_everywhere:
inputs = self.normalization_layer(inputs)
@@ -226,7 +214,7 @@ def forward(self, inputs, input_paddings):
inputs = F.relu(inputs)
inputs = inputs * padding_mask
- inputs = self.dropout(inputs)
+ inputs = F.dropout(inputs, dropout_rate, training=self.training)
return inputs
@@ -363,14 +351,14 @@ def __init__(self, config: DeepspeechConfig):
self.lin = nn.Linear(config.encoder_dim, config.vocab_size)
- def forward(self, inputs, input_paddings):
+ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE):
outputs = inputs
output_paddings = input_paddings
outputs, output_paddings = self.preprocessor(outputs, output_paddings)
if self.training and self.config.use_specaug:
outputs, output_paddings = self.specaug(outputs, output_paddings)
- outputs, output_paddings = self.subsample(outputs, output_paddings)
+ outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate)
for idx in range(self.config.num_lstm_layers):
if self.config.enable_residual_connections:
outputs = outputs + self.lstms[idx](outputs, output_paddings)
@@ -379,9 +367,9 @@ def forward(self, inputs, input_paddings):
for idx in range(self.config.num_ffn_layers):
if self.config.enable_residual_connections:
- outputs = outputs + self.ffns[idx](outputs, output_paddings)
+ outputs = outputs + self.ffns[idx](outputs, output_paddings, dropout_rate)
else:
- outputs = self.ffns[idx](outputs, output_paddings)
+ outputs = self.ffns[idx](outputs, output_paddings, dropout_rate)
if self.config.enable_decoder_layer_norm:
outputs = self.ln(outputs)
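
The DeepSpeech encoder gets the same treatment: the input and feed-forward dropout fields leave DeepspeechConfig, and forward() accepts dropout_rate with a fallback of DROPOUT_RATE = 0.1. A hedged sketch of a per-call override (input shapes are again illustrative assumptions):

    import torch
    from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import (
        DeepspeechConfig, DeepspeechEncoderDecoder)

    model = DeepspeechEncoderDecoder(DeepspeechConfig()).train()
    inputs = torch.randn(2, 16_000)
    paddings = torch.zeros(2, 16_000)

    logits, out_paddings = model(inputs, paddings)                    # falls back to 0.1
    logits, out_paddings = model(inputs, paddings, dropout_rate=0.3)  # explicit override
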
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py
deleted file mode 100644
index a8480a343..000000000
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models_dropout.py
+++ /dev/null
@@ -1,379 +0,0 @@
-"""This is a pytorch implementation mirroring:
-https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py.
-"""
-
-from dataclasses import dataclass
-import os
-from typing import Optional, Tuple
-
-import torch
-from torch import nn
-import torch.distributed.nn as dist_nn
-import torch.nn.functional as F
-
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \
- preprocessor
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \
- SpecAug
-
-USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
-DROPOUT_RATE = 0.1
-
-
-@dataclass
-class DeepspeechConfig:
- """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
- vocab_size: int = 1024
- encoder_dim: int = 512
- num_lstm_layers: int = 6
- num_ffn_layers: int = 3
- conv_subsampling_factor: int = 2
- conv_subsampling_layers: int = 2
- use_specaug: bool = True
- freq_mask_count: int = 2
- freq_mask_max_bins: int = 27
- time_mask_count: int = 10
- time_mask_max_frames: int = 40
- time_mask_max_ratio: float = 0.05
- time_masks_per_frame: float = 0.0
- use_dynamic_time_mask_max_frames: bool = True
- batch_norm_momentum: float = 1 - 0.999
- batch_norm_epsilon: float = 0.001
- enable_residual_connections: bool = True
- enable_decoder_layer_norm: bool = True
- bidirectional: bool = True
- use_tanh: bool = False
- layernorm_everywhere: bool = False
-
-
-class LayerNorm(nn.Module):
-
- def __init__(self, dim, epsilon=1e-6):
- super().__init__()
- self.dim = dim
-
- self.scale = nn.Parameter(torch.zeros(self.dim))
- self.bias = nn.Parameter(torch.zeros(self.dim))
- self.epsilon = epsilon
-
- def forward(self, x):
- mean = x.mean(dim=-1, keepdims=True)
- var = x.var(dim=-1, unbiased=False, keepdims=True)
-
- normed_x = (x - mean) * torch.rsqrt(var + self.epsilon)
- normed_x *= (1 + self.scale)
- normed_x += self.bias
-
- return normed_x
-
-
-class Subsample(nn.Module):
-
- def __init__(self, config: DeepspeechConfig):
- super().__init__()
- encoder_dim = config.encoder_dim
-
- self.encoder_dim = encoder_dim
-
- self.conv1 = Conv2dSubsampling(
- input_channels=1, output_channels=encoder_dim, use_tanh=config.use_tanh)
- self.conv2 = Conv2dSubsampling(
- input_channels=encoder_dim,
- output_channels=encoder_dim,
- use_tanh=config.use_tanh)
-
- self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True)
-
- def forward(self, inputs, input_paddings, dropout_rate):
-
- output_paddings = input_paddings
- outputs = inputs[:, None, :, :]
-
- outputs, output_paddings = self.conv1(outputs, output_paddings)
- outputs, output_paddings = self.conv2(outputs, output_paddings)
-
- batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape
- outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size,
- subsampled_lengths,
- subsampled_dims * channels)
-
- outputs = self.lin(outputs)
- outputs = F.dropout(outputs, dropout_rate, training=self.training)
-
- return outputs, output_paddings
-
-
-class Conv2dSubsampling(nn.Module):
-
- def __init__(self,
- input_channels: int,
- output_channels: int,
- filter_stride: Tuple[int] = (2, 2),
- padding: str = 'SAME',
- batch_norm_momentum: float = 0.999,
- batch_norm_epsilon: float = 0.001,
- use_tanh: bool = False):
- super().__init__()
-
- self.input_channels = input_channels
- self.output_channels = output_channels
- self.filter_stride = filter_stride
- self.padding = padding
-
- self.filter_shape = (output_channels, input_channels, 3, 3)
-
- self.kernel = nn.Parameter(
- nn.init.xavier_uniform_(torch.empty(*self.filter_shape)))
- self.bias = nn.Parameter(torch.zeros(output_channels))
-
- self.use_tanh = use_tanh
-
- def get_same_padding(self, input_shape):
- in_height, in_width = input_shape[2:]
- stride_height, stride_width = self.filter_stride
- filter_height, filter_width = 3, 3
- if in_height % stride_height == 0:
- pad_along_height = max(filter_height - stride_height, 0)
- else:
- pad_along_height = max(filter_height - (in_height % stride_height), 0)
- if in_width % stride_width == 0:
- pad_along_width = max(filter_width - stride_width, 0)
- else:
- pad_along_width = max(filter_width - (in_width % stride_width), 0)
- pad_top = pad_along_height // 2
- pad_bottom = pad_along_height - pad_top
- pad_left = pad_along_width // 2
- pad_right = pad_along_width - pad_left
- return (pad_left, pad_right, pad_top, pad_bottom)
-
- def forward(self, inputs, paddings):
- groups = inputs.shape[1] // self.input_channels
-
- if self.padding == 'SAME':
- in_ = F.pad(inputs, self.get_same_padding(inputs.shape))
- else:
- in_ = inputs
- outputs = F.conv2d(
- input=in_,
- weight=self.kernel,
- bias=self.bias,
- stride=self.filter_stride,
- dilation=(1, 1),
- groups=groups)
-
- if self.use_tanh:
- outputs = F.tanh(outputs)
- else:
- outputs = F.relu(outputs)
-
- input_length = paddings.shape[1]
- stride = self.filter_stride[0]
- pad_len = (input_length + stride - 1) // stride * stride - input_length
- out_padding = F.conv1d(
- input=torch.cat([
- paddings[:, None, :],
- torch.zeros(
- size=(paddings.shape[0], 1, pad_len), device=paddings.device)
- ],
- dim=2),
- weight=torch.ones([1, 1, 1], device=paddings.device),
- stride=self.filter_stride[:1])
- out_padding = out_padding.squeeze(dim=1)
- outputs = outputs * (1 - out_padding[:, None, :, None])
- return outputs, out_padding
-
-
-class FeedForwardModule(nn.Module):
-
- def __init__(self, config: DeepspeechConfig):
- super().__init__()
- self.config = config
-
- if config.layernorm_everywhere:
- self.normalization_layer = LayerNorm(config.encoder_dim)
- else:
- self.bn_normalization_layer = BatchNorm(
- dim=config.encoder_dim,
- batch_norm_momentum=config.batch_norm_momentum,
- batch_norm_epsilon=config.batch_norm_epsilon)
- self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True)
-
- def forward(self, inputs, input_paddings, dropout_rate):
-
- padding_mask = (1 - input_paddings)[:, :, None]
- if self.config.layernorm_everywhere:
- inputs = self.normalization_layer(inputs)
- else: # batchnorm
- inputs = self.bn_normalization_layer(inputs, input_paddings)
-
- inputs = self.lin(inputs)
-
- if self.config.use_tanh:
- inputs = F.tanh(inputs)
- else:
- inputs = F.relu(inputs)
-
- inputs = inputs * padding_mask
- inputs = F.dropout(inputs, dropout_rate, training=self.training)
-
- return inputs
-
-
-class BatchNorm(nn.Module):
-
- def __init__(self, dim, batch_norm_momentum, batch_norm_epsilon):
- super().__init__()
- running_mean = torch.zeros(dim)
- running_var = torch.ones(dim)
- self.register_buffer('running_mean', running_mean)
- self.register_buffer('running_var', running_var)
- self.weight = nn.Parameter(torch.zeros(dim))
- self.bias = nn.Parameter(torch.zeros(dim))
-
- self.momentum = batch_norm_momentum
- self.epsilon = batch_norm_epsilon
- self.dim = dim
-
- def forward(self, inputs, input_paddings):
- #inputs: NHD
- #padding: NH
- mask = 1 - input_paddings[:, :, None]
- if self.training:
- count = mask.sum()
- masked_inp = inputs.masked_fill(mask == 0, 0)
- sum_ = (masked_inp).sum(dim=(0, 1))
- if USE_PYTORCH_DDP:
- sum_ = dist_nn.all_reduce(sum_)
- count = dist_nn.all_reduce(count)
- mean = sum_ / count
-
- sum_ = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1))
- if USE_PYTORCH_DDP:
- sum_ = dist_nn.all_reduce(sum_)
- var = sum_ / count
-
- self.running_mean = (1 - self.momentum) * self.running_mean + (
- self.momentum) * mean.detach()
- self.running_var = (1 - self.momentum) * self.running_var + (
- self.momentum) * var.detach()
- else:
- mean = self.running_mean
- var = self.running_var
- v = (1 + self.weight) * torch.rsqrt(var + self.epsilon)
- bn = (inputs - mean) * v + self.bias
- output = bn.masked_fill(mask == 0, 0)
- return output
-
-
-class BatchRNN(nn.Module):
-
- def __init__(self, config: DeepspeechConfig):
- super().__init__()
- self.config = config
- hidden_size = config.encoder_dim
- input_size = config.encoder_dim
- bidirectional = config.bidirectional
- self.bidirectional = bidirectional
-
- if config.layernorm_everywhere:
- self.normalization_layer = LayerNorm(config.encoder_dim)
- else:
- self.bn_normalization_layer = BatchNorm(config.encoder_dim,
- config.batch_norm_momentum,
- config.batch_norm_epsilon)
-
- if bidirectional:
- self.lstm = nn.LSTM(
- input_size=input_size,
- hidden_size=hidden_size // 2,
- bidirectional=True,
- batch_first=True)
- else:
- self.lstm = nn.LSTM(
- input_size=input_size, hidden_size=hidden_size, batch_first=True)
-
- def forward(self, inputs, input_paddings):
- if self.config.layernorm_everywhere:
- inputs = self.normalization_layer(inputs)
- else:
- inputs = self.bn_normalization_layer(inputs, input_paddings)
- lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy()
- packed_inputs = torch.nn.utils.rnn.pack_padded_sequence(
- inputs, lengths, batch_first=True, enforce_sorted=False)
- packed_outputs, _ = self.lstm(packed_inputs)
- outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(
- packed_outputs, batch_first=True)
- if outputs.shape[1] < inputs.shape[1]:
- outputs = torch.cat([
- outputs,
- torch.zeros(
- size=(outputs.shape[0],
- inputs.shape[1] - outputs.shape[1],
- outputs.shape[2]),
- device=outputs.device)
- ],
- dim=1)
- return outputs
-
-
-class DeepspeechEncoderDecoder(nn.Module):
-
- def __init__(self, config: DeepspeechConfig):
- super().__init__()
- self.config = config
-
- self.specaug = SpecAug(
- freq_mask_count=config.freq_mask_count,
- freq_mask_max_bins=config.freq_mask_max_bins,
- time_mask_count=config.time_mask_count,
- time_mask_max_frames=config.time_mask_max_frames,
- time_mask_max_ratio=config.time_mask_max_ratio,
- time_masks_per_frame=config.time_masks_per_frame,
- use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames
- )
- preprocessing_config = preprocessor.PreprocessorConfig()
- self.preprocessor = preprocessor.MelFilterbankFrontend(
- preprocessing_config,
- per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
- per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)
-
- self.subsample = Subsample(config=config)
-
- self.lstms = nn.ModuleList(
- [BatchRNN(config) for _ in range(config.num_lstm_layers)])
- self.ffns = nn.ModuleList(
- [FeedForwardModule(config) for _ in range(config.num_ffn_layers)])
-
- if config.enable_decoder_layer_norm:
- self.ln = LayerNorm(config.encoder_dim)
- else:
- self.ln = nn.Identity()
-
- self.lin = nn.Linear(config.encoder_dim, config.vocab_size)
-
- def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE):
- outputs = inputs
- output_paddings = input_paddings
-
- outputs, output_paddings = self.preprocessor(outputs, output_paddings)
- if self.training and self.config.use_specaug:
- outputs, output_paddings = self.specaug(outputs, output_paddings)
- outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate)
- for idx in range(self.config.num_lstm_layers):
- if self.config.enable_residual_connections:
- outputs = outputs + self.lstms[idx](outputs, output_paddings)
- else:
- outputs = self.lstms[idx](outputs, output_paddings)
-
- for idx in range(self.config.num_ffn_layers):
- if self.config.enable_residual_connections:
- outputs = outputs + self.ffns[idx](outputs, output_paddings, dropout_rate)
- else:
- outputs = self.ffns[idx](outputs, output_paddings, dropout_rate)
-
- if self.config.enable_decoder_layer_norm:
- outputs = self.ln(outputs)
-
- outputs = self.lin(outputs)
-
- return outputs, output_paddings
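
Independent of the dropout plumbing, the BatchRNN removed above re-pads the LSTM output because pad_packed_sequence only pads to the longest sequence in the batch, not to the padded input length. A standalone sketch of that step with toy sizes:

    import torch
    from torch import nn

    lstm = nn.LSTM(input_size=8, hidden_size=8, batch_first=True)
    inputs = torch.randn(3, 10, 8)        # padded to 10 frames
    lengths = torch.tensor([8, 6, 4])     # true lengths per example

    packed = nn.utils.rnn.pack_padded_sequence(
        inputs, lengths, batch_first=True, enforce_sorted=False)
    packed_out, _ = lstm(packed)
    outputs, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)

    # pad_packed_sequence pads only to max(lengths) == 8; restore the time dim of 10.
    if outputs.shape[1] < inputs.shape[1]:
        pad = torch.zeros(outputs.shape[0],
                          inputs.shape[1] - outputs.shape[1],
                          outputs.shape[2])
        outputs = torch.cat([outputs, pad], dim=1)
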
diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models.py b/algoperf/workloads/ogbg/ogbg_pytorch/models.py
index fe9b29bc1..be5882333 100644
--- a/algoperf/workloads/ogbg/ogbg_pytorch/models.py
+++ b/algoperf/workloads/ogbg/ogbg_pytorch/models.py
@@ -9,17 +9,20 @@
from torch import nn
from algoperf import init_utils
+from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout
+DROPOUT_RATE = 0.1
-def _make_mlp(in_dim, hidden_dims, dropout_rate, activation_fn):
+
+def _make_mlp(in_dim, hidden_dims, activation_fn):
"""Creates a MLP with specified dimensions."""
- layers = nn.Sequential()
+ layers = SequentialWithDropout()
for i, dim in enumerate(hidden_dims):
layers.add_module(f'dense_{i}',
nn.Linear(in_features=in_dim, out_features=dim))
layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6))
layers.add_module(f'activation_fn_{i}', activation_fn())
- layers.add_module(f'dropout_{i}', nn.Dropout(dropout_rate))
+ layers.add_module(f'dropout_{i}', CustomDropout())
in_dim = dim
return layers
@@ -33,7 +36,6 @@ class GNN(nn.Module):
def __init__(self,
num_outputs: int = 128,
- dropout_rate: Optional[float] = 0.1,
activation_fn_name: str = 'relu',
latent_dim: int = 256,
hidden_dims: Tuple[int] = (256,),
@@ -43,8 +45,6 @@ def __init__(self,
self.hidden_dims = hidden_dims
self.num_message_passing_steps = num_message_passing_steps
self.num_outputs = num_outputs
- if dropout_rate is None:
- dropout_rate = 0.1
# in_features are specifically chosen for the ogbg workload.
self.node_embedder = nn.Linear(in_features=9, out_features=self.latent_dim)
self.edge_embedder = nn.Linear(in_features=3, out_features=self.latent_dim)
@@ -77,17 +77,14 @@ def __init__(self,
GraphNetwork(
update_edge_fn=_make_mlp(in_dim_edge_fn,
self.hidden_dims,
- dropout_rate,
activation_fn),
update_node_fn=_make_mlp(in_dim_node_fn,
self.hidden_dims,
- dropout_rate,
activation_fn),
update_global_fn=_make_mlp(last_in_dim,
self.hidden_dims,
- dropout_rate,
activation_fn)))
- self.graph_network = nn.Sequential(*graph_network_layers)
+ self.graph_network = SequentialWithDropout(*graph_network_layers)
self.decoder = nn.Linear(
in_features=self.hidden_dims[-1], out_features=self.num_outputs)
@@ -96,14 +93,18 @@ def __init__(self,
if isinstance(m, nn.Linear):
init_utils.pytorch_default_init(m)
- def forward(self, graph: GraphsTuple) -> torch.Tensor:
+ def forward(
+ self,
+ graph: GraphsTuple,
+ dropout_rate: float = DROPOUT_RATE) -> torch.Tensor:
+
graph = graph._replace(
globals=torch.zeros([graph.n_node.shape[0], self.num_outputs],
device=graph.n_node.device))
graph = graph._replace(nodes=self.node_embedder(graph.nodes))
graph = graph._replace(edges=self.edge_embedder(graph.edges))
- graph = self.graph_network(graph)
+ graph = self.graph_network(graph, dropout_rate)
# Map globals to represent the final result
graph = graph._replace(globals=self.decoder(graph.globals))
@@ -145,8 +146,9 @@ def __init__(self,
self.update_edge_fn = update_edge_fn
self.update_node_fn = update_node_fn
self.update_global_fn = update_global_fn
+ self._supports_custom_dropout = True # supports SequentialWithDropout
- def forward(self, graph: GraphsTuple) -> GraphsTuple:
+ def forward(self, graph: GraphsTuple, dropout_rate: float) -> GraphsTuple:
"""Applies a configured GraphNetwork to a graph.
This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261
There is one difference. For the nodes update the class aggregates over the
@@ -159,6 +161,7 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple:
GraphNets, for more information please see the paper.
Args:
graph: a `GraphsTuple` containing the graph.
+ dropout_rate: dropout probability value.
Returns:
Updated `GraphsTuple`.
"""
@@ -179,7 +182,7 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple:
edge_fn_inputs = torch.cat(
[edges, sent_attributes, received_attributes, global_edge_attributes],
dim=-1)
- edges = self.update_edge_fn(edge_fn_inputs)
+ edges = self.update_edge_fn(edge_fn_inputs, dropout_rate)
if self.update_node_fn:
sent_attributes = tree.tree_map(
@@ -194,7 +197,7 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple:
node_fn_inputs = torch.cat(
[nodes, sent_attributes, received_attributes, global_attributes],
dim=-1)
- nodes = self.update_node_fn(node_fn_inputs)
+ nodes = self.update_node_fn(node_fn_inputs, dropout_rate)
if self.update_global_fn:
n_graph = n_node.shape[0]
@@ -213,7 +216,7 @@ def forward(self, graph: GraphsTuple) -> GraphsTuple:
# These pooled nodes are the inputs to the global update fn.
global_fn_inputs = torch.cat([node_attributes, edge_attributes, globals_],
dim=-1)
- globals_ = self.update_global_fn(global_fn_inputs)
+ globals_ = self.update_global_fn(global_fn_inputs, dropout_rate)
return GraphsTuple(
nodes=nodes,
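
For the GNN, the per-layer nn.Dropout modules are replaced by CustomDropout layers inside a SequentialWithDropout container from algoperf.pytorch_utils, which lets the rate be forwarded through nested containers at call time. Those helpers are not shown in this patch; the sketch below is only one plausible shape for them, keyed off the _supports_custom_dropout flag that GraphNetwork sets above, and should not be read as the actual implementation:

    import torch
    import torch.nn.functional as F
    from torch import nn

    class CustomDropoutSketch(nn.Module):
        """Illustrative stand-in for algoperf.pytorch_utils.CustomDropout."""

        def forward(self, x: torch.Tensor, dropout_rate: float) -> torch.Tensor:
            return F.dropout(x, dropout_rate, training=self.training)

    class SequentialWithDropoutSketch(nn.Sequential):
        """Illustrative stand-in that threads dropout_rate to rate-aware children."""

        def forward(self, x, dropout_rate: float):
            for module in self:
                if isinstance(module, CustomDropoutSketch) or getattr(
                    module, '_supports_custom_dropout', False):
                    x = module(x, dropout_rate)
                else:
                    x = module(x)
            return x

    mlp = SequentialWithDropoutSketch(nn.Linear(4, 4), CustomDropoutSketch())
    out = mlp(torch.randn(2, 4), dropout_rate=0.1)
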
diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py b/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py
deleted file mode 100644
index be5882333..000000000
--- a/algoperf/workloads/ogbg/ogbg_pytorch/models_dropout.py
+++ /dev/null
@@ -1,315 +0,0 @@
-# Ported to PyTorch from
-# https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py.
-from functools import partial
-from typing import Callable, Optional, Tuple
-
-import jax.tree_util as tree
-from jraph import GraphsTuple
-import torch
-from torch import nn
-
-from algoperf import init_utils
-from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout
-
-DROPOUT_RATE = 0.1
-
-
-def _make_mlp(in_dim, hidden_dims, activation_fn):
- """Creates a MLP with specified dimensions."""
- layers = SequentialWithDropout()
- for i, dim in enumerate(hidden_dims):
- layers.add_module(f'dense_{i}',
- nn.Linear(in_features=in_dim, out_features=dim))
- layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6))
- layers.add_module(f'activation_fn_{i}', activation_fn())
- layers.add_module(f'dropout_{i}', CustomDropout())
- in_dim = dim
- return layers
-
-
-class GNN(nn.Module):
- """Defines a graph network.
-
- The model assumes the input data is a jraph.GraphsTuple without global
- variables. The final prediction will be encoded in the globals.
- """
-
- def __init__(self,
- num_outputs: int = 128,
- activation_fn_name: str = 'relu',
- latent_dim: int = 256,
- hidden_dims: Tuple[int] = (256,),
- num_message_passing_steps: int = 5) -> None:
- super().__init__()
- self.latent_dim = latent_dim
- self.hidden_dims = hidden_dims
- self.num_message_passing_steps = num_message_passing_steps
- self.num_outputs = num_outputs
- # in_features are specifically chosen for the ogbg workload.
- self.node_embedder = nn.Linear(in_features=9, out_features=self.latent_dim)
- self.edge_embedder = nn.Linear(in_features=3, out_features=self.latent_dim)
-
- if activation_fn_name == 'relu':
- activation_fn = nn.ReLU
- elif activation_fn_name == 'gelu':
- activation_fn = partial(nn.GELU, approximate='tanh')
- elif activation_fn_name == 'silu':
- activation_fn = nn.SiLU
- else:
- raise ValueError(
- f'Invalid activation function name: {activation_fn_name}')
-
- graph_network_layers = []
- for st in range(self.num_message_passing_steps):
- # Constants in in_dims are based on forward call of GraphNetwork:
- # specifically update_edge_fn update_node_fn and update_global_fn.
- if st == 0:
- in_dim_edge_fn = self.latent_dim * 3 + self.num_outputs
- in_dim_node_fn = self.latent_dim + self.hidden_dims[
- -1] * 2 + self.num_outputs
- last_in_dim = self.hidden_dims[-1] * 2 + self.num_outputs
- else:
- in_dim_edge_fn = self.hidden_dims[-1] * 4
- in_dim_node_fn = self.hidden_dims[-1] * 4
- last_in_dim = self.hidden_dims[-1] * 3
-
- graph_network_layers.append(
- GraphNetwork(
- update_edge_fn=_make_mlp(in_dim_edge_fn,
- self.hidden_dims,
- activation_fn),
- update_node_fn=_make_mlp(in_dim_node_fn,
- self.hidden_dims,
- activation_fn),
- update_global_fn=_make_mlp(last_in_dim,
- self.hidden_dims,
- activation_fn)))
- self.graph_network = SequentialWithDropout(*graph_network_layers)
-
- self.decoder = nn.Linear(
- in_features=self.hidden_dims[-1], out_features=self.num_outputs)
-
- for m in self.modules():
- if isinstance(m, nn.Linear):
- init_utils.pytorch_default_init(m)
-
- def forward(
- self,
- graph: GraphsTuple,
- dropout_rate: float = DROPOUT_RATE) -> torch.Tensor:
-
- graph = graph._replace(
- globals=torch.zeros([graph.n_node.shape[0], self.num_outputs],
- device=graph.n_node.device))
- graph = graph._replace(nodes=self.node_embedder(graph.nodes))
- graph = graph._replace(edges=self.edge_embedder(graph.edges))
-
- graph = self.graph_network(graph, dropout_rate)
-
- # Map globals to represent the final result
- graph = graph._replace(globals=self.decoder(graph.globals))
-
- return graph.globals
-
-
-class GraphNetwork(nn.Module):
- """Returns a method that applies a configured GraphNetwork.
- This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261
- There is one difference. For the nodes update the class aggregates over the
- sender edges and receiver edges separately. This is a bit more general
- than the algorithm described in the paper. The original behaviour can be
- recovered by using only the receiver edge aggregations for the update.
- In addition this implementation supports softmax attention over incoming
- edge features.
- Example usage::
- gn = GraphNetwork(update_edge_function,
- update_node_function, **kwargs)
- # Conduct multiple rounds of message passing with the same parameters:
- for _ in range(num_message_passing_steps):
- graph = gn(graph)
- Args:
- update_edge_fn: function used to update the edges or None to deactivate edge
- updates.
- update_node_fn: function used to update the nodes or None to deactivate node
- updates.
- update_global_fn: function used to update the globals or None to deactivate
- globals updates.
- Returns:
- A method that applies the configured GraphNetwork.
- """
-
- def __init__(self,
- update_edge_fn: Optional[Callable] = None,
- update_node_fn: Optional[Callable] = None,
- update_global_fn: Optional[Callable] = None) -> None:
- super().__init__()
- self.update_edge_fn = update_edge_fn
- self.update_node_fn = update_node_fn
- self.update_global_fn = update_global_fn
- self._supports_custom_dropout = True # supports SequentialWithDropout
-
- def forward(self, graph: GraphsTuple, dropout_rate: float) -> GraphsTuple:
- """Applies a configured GraphNetwork to a graph.
- This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261
- There is one difference. For the nodes update the class aggregates over the
- sender edges and receiver edges separately. This is a bit more general
- than the algorithm described in the paper. The original behaviour can be
- recovered by using only the receiver edge aggregations for the update.
- In addition this implementation supports softmax attention over incoming
- edge features.
- Many popular Graph Neural Networks can be implemented as special cases of
- GraphNets, for more information please see the paper.
- Args:
- graph: a `GraphsTuple` containing the graph.
- dropout_rate: dropout probability value.
- Returns:
- Updated `GraphsTuple`.
- """
- nodes, edges, receivers, senders, globals_, n_node, n_edge = graph
- sum_n_node = tree.tree_leaves(nodes)[0].shape[0]
- if not tree.tree_all(
- tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)):
- raise ValueError(
- 'All node arrays in nest must contain the same number of nodes.')
-
- sent_attributes = tree.tree_map(lambda n: n[senders], nodes)
- received_attributes = tree.tree_map(lambda n: n[receivers], nodes)
- # Here we scatter the global features to the corresponding edges,
- # giving us tensors of shape [num_edges, global_feat].
- global_edge_attributes = tree.tree_map(
- lambda g: torch.repeat_interleave(g, n_edge, dim=0), globals_)
- if self.update_edge_fn:
- edge_fn_inputs = torch.cat(
- [edges, sent_attributes, received_attributes, global_edge_attributes],
- dim=-1)
- edges = self.update_edge_fn(edge_fn_inputs, dropout_rate)
-
- if self.update_node_fn:
- sent_attributes = tree.tree_map(
- lambda e: scatter_sum(e, senders, dim=0, dim_size=sum_n_node), edges)
- received_attributes = tree.tree_map(
- lambda e: scatter_sum(e, receivers, dim=0, dim_size=sum_n_node),
- edges)
- # Here we scatter the global features to the corresponding nodes,
- # giving us tensors of shape [num_nodes, global_feat].
- global_attributes = tree.tree_map(
- lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_)
- node_fn_inputs = torch.cat(
- [nodes, sent_attributes, received_attributes, global_attributes],
- dim=-1)
- nodes = self.update_node_fn(node_fn_inputs, dropout_rate)
-
- if self.update_global_fn:
- n_graph = n_node.shape[0]
- graph_idx = torch.arange(n_graph, device=graph.n_node.device)
- # To aggregate nodes and edges from each graph to global features,
- # we first construct tensors that map the node to the corresponding graph.
- # For example, if you have `n_node=[1,2]`, we construct the tensor
- # [0, 1, 1]. We then do the same for edges.
- node_gr_idx = torch.repeat_interleave(graph_idx, n_node, dim=0)
- edge_gr_idx = torch.repeat_interleave(graph_idx, n_edge, dim=0)
- # We use the aggregation function to pool the nodes/edges per graph.
- node_attributes = tree.tree_map(
- lambda n: scatter_sum(n, node_gr_idx, dim=0, dim_size=n_graph), nodes)
- edge_attributes = tree.tree_map(
- lambda e: scatter_sum(e, edge_gr_idx, dim=0, dim_size=n_graph), edges)
- # These pooled nodes are the inputs to the global update fn.
- global_fn_inputs = torch.cat([node_attributes, edge_attributes, globals_],
- dim=-1)
- globals_ = self.update_global_fn(global_fn_inputs, dropout_rate)
-
- return GraphsTuple(
- nodes=nodes,
- edges=edges,
- receivers=receivers,
- senders=senders,
- globals=globals_,
- n_node=n_node,
- n_edge=n_edge)
-
-
-# Forked from
-# github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/scatter.py.
-def scatter_sum(src: torch.Tensor,
- index: torch.Tensor,
- dim: int = -1,
- out: Optional[torch.Tensor] = None,
- dim_size: Optional[int] = None) -> torch.Tensor:
- r"""
- |
- .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
- master/docs/source/_figures/add.svg?sanitize=true
- :align: center
- :width: 400px
- |
- Reduces all values from the :attr:`src` tensor into :attr:`out` at the
- indices specified in the :attr:`index` tensor along a given axis
- :attr:`dim`.
- For each value in :attr:`src`, its output index is specified by its index
- in :attr:`src` for dimensions outside of :attr:`dim` and by the
- corresponding value in :attr:`index` for dimension :attr:`dim`.
- The applied reduction is here defined as a sum.
- Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional
- tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})`
- and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional
- tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`.
- Moreover, the values of :attr:`index` must be between :math:`0` and
- :math:`y - 1`, although no specific ordering of indices is required.
- The :attr:`index` tensor supports broadcasting in case its dimensions do
- not match with :attr:`src`.
- For one-dimensional tensors, the operation computes
- .. math::
- \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j
- where :math:`\sum_j` is over :math:`j` such that
- :math:`\mathrm{index}_j = i`.
- .. note::
- This operation is implemented via atomic operations on the GPU and is
- therefore **non-deterministic** since the order of parallel operations
- to the same value is undetermined.
- For floating-point variables, this results in a source of variance in
- the result.
- :param src: The source tensor.
- :param index: The indices of elements to scatter.
- :param dim: The axis along which to index. (default: :obj:`-1`)
- :param out: The destination tensor.
- :param dim_size: If :attr:`out` is not given, automatically create output
- with size :attr:`dim_size` at dimension :attr:`dim`.
- If :attr:`dim_size` is not given, a minimal sized output tensor
- according to :obj:`index.max() + 1` is returned.
- :rtype: :class:`Tensor`
- .. code-block:: python
- src = torch.randn(10, 6, 64)
- index = torch.tensor([0, 1, 0, 1, 2, 1])
- # Broadcasting in the first and last dim.
- out = scatter_sum(src, index, dim=1)
- print(out.size())
- .. code-block::
- torch.Size([10, 3, 64])
- """
- index = broadcast(index, src, dim)
- if out is None:
- size = list(src.size())
- if dim_size is not None:
- size[dim] = dim_size
- elif index.numel() == 0:
- size[dim] = 0
- else:
- size[dim] = int(index.max()) + 1
- out = torch.zeros(size, dtype=src.dtype, device=src.device)
- return out.scatter_add_(dim, index, src)
- else:
- return out.scatter_add_(dim, index, src)
-
-
-# Forked from
-# github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/utils.py.
-def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
- if dim < 0:
- dim = other.dim() + dim
- if src.dim() == 1:
- for _ in range(0, dim):
- src = src.unsqueeze(0)
- for _ in range(src.dim(), other.dim()):
- src = src.unsqueeze(-1)
- src = src.expand(other.size())
- return src
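Editor's note: for reference, the semantics of the forked scatter_sum above can be checked with a short sketch. This assumes the scatter_sum defined above is in scope; the shapes mirror the docstring example and the values are illustrative only.

import torch

src = torch.randn(10, 6, 64)
index = torch.tensor([0, 1, 0, 1, 2, 1])
# index is broadcast over dims 0 and 2, then summed into 3 segments along dim=1.
out = scatter_sum(src, index, dim=1)
assert out.shape == (10, 3, 64)
# Segment 2 receives only position 4 of src, so its "sum" is exactly that slice.
assert torch.allclose(out[:, 2, :], src[:, 4, :])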
diff --git a/algoperf/workloads/wmt/wmt_pytorch/models.py b/algoperf/workloads/wmt/wmt_pytorch/models.py
index a1c7ce15e..a43df30d4 100644
--- a/algoperf/workloads/wmt/wmt_pytorch/models.py
+++ b/algoperf/workloads/wmt/wmt_pytorch/models.py
@@ -9,6 +9,8 @@
from torch.nn.init import normal_
from torch.nn.init import xavier_uniform_
+DROPOUT_RATE = 0.1
+
def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor:
"""Make a causal mask for self-attention.
@@ -104,26 +106,18 @@ def __init__(self,
nhead: int = 16,
d_hid: int = 1024,
nlayers: int = 6,
- dropout_rate: Optional[float] = 0.1,
- attention_dropout_rate: Optional[float] = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
glu: bool = False,
layer_norm_eps: float = 1e-6,
attention_temp: float = 1.0,
pre_ln: bool = True):
super().__init__()
- if dropout_rate is None:
- dropout_rate = 0.1
- if attention_dropout_rate is None:
- attention_dropout_rate = 0.1
- self.pos_encoder = PositionalEncoding(d_model, dropout_rate)
+ self.pos_encoder = PositionalEncoding(d_model)
self.shared_embedding = nn.Embedding(ntoken, d_model)
self.encoder = Encoder(d_model,
nhead,
d_hid,
nlayers,
- dropout_rate,
- attention_dropout_rate,
activation,
glu,
layer_norm_eps,
@@ -133,8 +127,6 @@ def __init__(self,
nhead,
d_hid,
nlayers,
- dropout_rate,
- attention_dropout_rate,
activation,
glu,
layer_norm_eps,
@@ -163,7 +155,8 @@ def forward(self,
targets_positions: Optional[Tensor] = None,
inputs_segmentation: Optional[Tensor] = None,
targets_segmentation: Optional[Tensor] = None,
- decode: bool = False) -> Tensor:
+ decode: bool = False,
+ dropout_rate: float = DROPOUT_RATE) -> Tensor:
"""
Args:
src: Tensor, shape [batch_size, seq_len]
@@ -173,16 +166,19 @@ def forward(self,
inputs_segmentation: Optional[Tensor], shape [batch_size, seq_len]
targets_segmentation: Optional[Tensor], shape [batch_size, seq_len]
decode: bool
+ dropout_rate: float
Returns:
output Tensor of shape [batch_size, seq_len, ntoken]
"""
if src.size(0) != tgt.size(0):
raise RuntimeError('The batch size of src and tgt must be equal.')
+
memory = self.encoder(
src,
inputs_positions=inputs_positions,
- inputs_segmentation=inputs_segmentation)
+ inputs_segmentation=inputs_segmentation,
+ dropout_rate=dropout_rate)
output = self.decoder(
tgt,
memory,
@@ -190,7 +186,8 @@ def forward(self,
targets_positions=targets_positions,
inputs_segmentation=inputs_segmentation,
targets_segmentation=targets_segmentation,
- decode=decode)
+ decode=decode,
+ dropout_rate=dropout_rate)
return output
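Editor's note: with this hunk the Transformer no longer owns a dropout probability; callers pass it on every forward call. A minimal call sketch, assuming the patched Transformer class is importable and using small illustrative hyperparameters rather than the workload defaults.

import torch

model = Transformer(ntoken=320, d_model=32, nhead=4, d_hid=64, nlayers=2)
src = torch.randint(1, 320, (8, 16))   # [batch_size, seq_len]
tgt = torch.randint(1, 320, (8, 16))

model.train()
out = model(src, tgt, dropout_rate=0.3)   # rate chosen per call, e.g. from a tuning trial
model.eval()
out = model(src, tgt)                     # falls back to DROPOUT_RATE; inactive in eval mode
assert out.shape == (8, 16, 320)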
@@ -229,12 +226,15 @@ def __init__(self,
self.enable_nested_tensor = enable_nested_tensor
self.mask_check = mask_check
- def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
+ def forward(self, src: Tensor,
+ mask: Optional[Tensor] = None,
+ dropout_rate: Optional[float] = 0.0) -> Tensor:
"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
mask: the mask for the src sequence (optional).
+ dropout_rate: the dropout probability (optional).
Shape:
see the docs in Transformer class.
@@ -243,7 +243,7 @@ def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
convert_to_nested = False
for mod in self.layers:
- output = mod(output, src_mask=mask)
+ output = mod(output, src_mask=mask, dropout_rate=dropout_rate)
if convert_to_nested:
output = output.to_padded_tensor(0.)
@@ -261,8 +261,6 @@ def __init__(self,
nhead: int = 16,
d_hid: int = 1024,
nlayers: int = 6,
- dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
glu: bool = False,
layer_norm_eps: float = 1e-6,
@@ -276,8 +274,6 @@ def __init__(self,
d_model,
nhead,
d_hid,
- dropout_rate,
- attention_dropout_rate=attention_dropout_rate,
activation=activation,
glu=glu,
layer_norm_eps=layer_norm_eps,
@@ -290,12 +286,13 @@ def __init__(self,
def forward(self,
src: Tensor,
inputs_positions: Optional[Tensor] = None,
- inputs_segmentation: Optional[Tensor] = None) -> Tensor:
+ inputs_segmentation: Optional[Tensor] = None,
+ dropout_rate: Optional[float] = 0.0) -> Tensor:
src = src.to(torch.int)
src_mask = make_src_mask(src, inputs_segmentation, self.nhead)
src = self.shared_embedding(src)
- src = self.pos_encoder(src, inputs_positions)
- memory = self.encoder(src, mask=src_mask)
+ src = self.pos_encoder(src, inputs_positions, dropout_rate=dropout_rate)
+ memory = self.encoder(src, mask=src_mask, dropout_rate=dropout_rate)
return memory
@@ -306,8 +303,6 @@ def __init__(self,
nhead: int = 16,
d_hid: int = 1024,
nlayers: int = 6,
- dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
glu: bool = False,
layer_norm_eps: float = 1e-6,
@@ -320,8 +315,6 @@ def __init__(self,
self.decoder = TransformerDecoder(d_model,
nhead,
d_hid,
- dropout_rate,
- attention_dropout_rate,
activation,
glu,
layer_norm_eps,
@@ -339,7 +332,8 @@ def forward(
targets_segmentation: Optional[Tensor] = None,
decode: bool = False,
max_len: Optional[int] = None,
- cache: Optional[dict] = None) -> Any:
+ cache: Optional[dict] = None,
+ dropout_rate: Optional[float] = 0.0) -> Any:
tgt = tgt.to(torch.int)
tgt_mask, memory_mask = make_tgt_and_memory_mask(
tgt, src, inputs_segmentation, targets_segmentation,
@@ -347,7 +341,7 @@ def forward(
if not decode:
tgt = shift_right(tgt)
tgt = self.shared_embedding(tgt)
- tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache)
+ tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache, dropout_rate=dropout_rate)
if decode:
tgt, cache = tgt
output = self.decoder(
@@ -357,7 +351,8 @@ def forward(
memory_mask=memory_mask,
decode=decode,
max_len=max_len,
- cache=cache)
+ cache=cache,
+ dropout_rate=dropout_rate)
if decode:
output, cache = output
normalize = math.sqrt(output.shape[-1])
@@ -371,10 +366,8 @@ class PositionalEncoding(nn.Module):
def __init__(self,
d_model: int,
- dropout_rate: float = 0.1,
max_len: int = 256):
super().__init__()
- self.dropout = nn.Dropout(p=dropout_rate)
position = torch.arange(max_len).unsqueeze(1)
scale_factor = -math.log(10000.0) / (d_model // 2 - 1)
@@ -389,7 +382,8 @@ def forward(
x: Tensor,
inputs_positions: Optional[Tensor] = None,
decode: bool = False,
- cache: Optional[Dict[str, Dict[str, Tensor]]] = None
+ cache: Optional[Dict[str, Dict[str, Tensor]]] = None,
+ dropout_rate: Optional[float] = 0.0
) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]:
"""
Args:
@@ -397,6 +391,7 @@ def forward(
inputs_positions: Tensor (shape [batch_size, seq_len]) or None
decode: bool
cache: Dict[str, Dict[str, Tensor]] or None
+ dropout_rate: Optional[float]
Returns:
Tensor or Tuple[Tensor, Dict[str, Dict[str, Tensor]]]
"""
@@ -412,14 +407,14 @@ def forward(
}
pe = self.pe[0, cache[name]['cache_index'], :]
cache[name]['cache_index'] += 1
- return self.dropout(x + pe), cache
+ return F.dropout(x + pe, dropout_rate, self.training), cache
if inputs_positions is None:
# normal unpacked case:
pe = self.pe[:, :x.size(1), :]
else:
# for packed data we need to use known position indices:
pe = self.pe[0, inputs_positions, :]
- return self.dropout(x + pe)
+ return F.dropout(x + pe, dropout_rate, self.training)
# TransformerEncoderLayer and TransformerDecoderLayer are taken from:
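Editor's note: the stored nn.Dropout module is replaced by torch.nn.functional.dropout so the rate can be a per-call argument; the only behavioural dependency left is self.training. A small equivalence sketch (values illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(4, 16, 32)
drop = torch.nn.Dropout(p=0.1)

drop.train()
a = drop(x)                              # module form: rate fixed at construction
b = F.dropout(x, 0.1, training=True)     # functional form: rate supplied per call
# Same scaling and distribution; only the random masks differ.

drop.eval()
assert torch.equal(drop(x), F.dropout(x, 0.1, training=False))  # both are identity in eval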
@@ -438,7 +433,6 @@ class TransformerEncoderLayer(nn.Module):
nhead: the number of heads in the multiheadattention models (default=16).
dim_feedforward: the dimension of the feedforward network model
(default=1024).
- dropout_rate: the dropout_rate value (default=0.1).
activation: the activation function of the intermediate layer, can be a
string ("relu" or "gelu") or a unary callable (default=F.relu).
layer_norm_eps: the eps value in layer normalization components
@@ -457,8 +451,6 @@ def __init__(self,
d_model: int = 1024,
nhead: int = 16,
dim_feedforward: int = 1024,
- dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
glu: bool = False,
layer_norm_eps: float = 1e-6,
@@ -472,7 +464,6 @@ def __init__(self,
d_model,
nhead,
self_attn=True,
- dropout_rate=attention_dropout_rate,
attention_temp=attention_temp,
bias=False,
**factory_kwargs)
@@ -482,50 +473,55 @@ def __init__(self,
self.glu = glu
if self.glu:
self.linear_glu = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
- self.dropout = nn.Dropout(dropout_rate)
self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
self.pre_ln = pre_ln
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
- self.dropout1 = nn.Dropout(dropout_rate)
- self.dropout2 = nn.Dropout(dropout_rate)
self.activation = activation
- def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
+ def forward(self,
+ src: Tensor,
+ src_mask: Optional[Tensor] = None,
+ dropout_rate: Optional[float] = 0.0) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
-
+ dropout_rate: the dropout probability value (optional).
Shape:
see the docs in Transformer class.
"""
x = src
if self.pre_ln:
- x = x + self._sa_block(self.norm1(x), src_mask)
- x = x + self._ff_block(self.norm2(x))
+ x = x + self._sa_block(self.norm1(x), src_mask, dropout_rate)
+ x = x + self._ff_block(self.norm2(x), dropout_rate)
else:
- x = self.norm1(x + self._sa_block(x, src_mask))
- x = self.norm2(x + self._ff_block(x))
+ x = self.norm1(x + self._sa_block(x, src_mask, dropout_rate))
+ x = self.norm2(x + self._ff_block(x, dropout_rate))
return x
# Self-attention block:
- def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor]) -> Tensor:
- x, _ = self.self_attn(x, attn_mask=attn_mask)
- return self.dropout1(x)
+ def _sa_block(self,
+ x: Tensor,
+ attn_mask: Optional[Tensor],
+ dropout_rate: Optional[float] = 0.0) -> Tensor:
+ x, _ = self.self_attn(x, attn_mask=attn_mask, dropout_rate=dropout_rate)
+ return F.dropout(x, dropout_rate, training=self.training)
# Feed forward block:
- def _ff_block(self, inputs: Tensor) -> Tensor:
+ def _ff_block(self,
+ inputs: Tensor,
+ dropout_rate: Optional[float] = 0.0) -> Tensor:
x = self.activation(self.linear1(inputs))
if self.glu:
y = self.linear_glu(inputs)
x = x * y
- x = self.linear2(self.dropout(x))
- return self.dropout2(x)
+ x = self.linear2(F.dropout(x, dropout_rate, training=self.training))
+ return F.dropout(x, dropout_rate, training=self.training)
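Editor's note: the residual wiring of the encoder layer is unchanged; only the way the rate reaches the sub-blocks differs, so no per-layer dropout modules are needed. A minimal pre-LN sketch with a hypothetical stand-in for the attention sub-block (TinyPreLNBlock and its Linear "mix" are illustrative, not the workload layer):

import torch
import torch.nn.functional as F
from torch import nn

class TinyPreLNBlock(nn.Module):
  """Illustrative stand-in for a pre-LN residual block with call-time dropout."""

  def __init__(self, d):
    super().__init__()
    self.norm1 = nn.LayerNorm(d)
    self.norm2 = nn.LayerNorm(d)
    self.mix = nn.Linear(d, d)   # stands in for self-attention
    self.ff = nn.Linear(d, d)    # stands in for the feed-forward block

  def forward(self, x, dropout_rate=0.0):
    x = x + F.dropout(self.mix(self.norm1(x)), dropout_rate, self.training)
    x = x + F.dropout(self.ff(self.norm2(x)), dropout_rate, self.training)
    return x

blk = TinyPreLNBlock(32)
y = blk(torch.randn(2, 8, 32), dropout_rate=0.2)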
# Modified to use cache for autoregressive decoding and custom
@@ -537,7 +533,6 @@ class TransformerDecoder(nn.Module):
nhead: the number of heads in the multiheadattention models (default=16)
d_hid: the dimension of the feedforward network model
(default=1024)
- dropout_rate: the dropout_rate value (default=0.1)
layer_norm_eps: the eps value in layer normalization components
(default=1e-6).
decoder_layer: an instance of the TransformerDecoderLayer() class
@@ -555,8 +550,6 @@ def __init__(self,
d_model,
nhead,
d_hid,
- dropout_rate,
- attention_dropout_rate,
activation,
glu,
layer_norm_eps,
@@ -569,8 +562,6 @@ def __init__(self,
d_model,
nhead,
d_hid,
- dropout_rate,
- attention_dropout_rate,
activation,
glu,
layer_norm_eps=layer_norm_eps,
@@ -587,7 +578,8 @@ def forward(self,
memory_mask: Optional[Tensor] = None,
decode: bool = False,
max_len: Optional[int] = None,
- cache: Optional[dict] = None) -> Any:
+ cache: Optional[dict] = None,
+ dropout_rate: Optional[float] = 0.0) -> Any:
r"""Pass the inputs (and mask) through the decoder layer in turn.
Args:
tgt: the sequence to the decoder (required).
@@ -596,6 +588,7 @@ def forward(self,
memory_mask: the mask for the memory sequence (optional).
decode: whether to use cache for autoregressive decoding or not.
max_len: maximum sequence length, necessary for decoding cache.
+ dropout_rate: the dropout probability value (optional)
Shape:
see the docs in Transformer class.
"""
@@ -610,7 +603,8 @@ def forward(self,
decode=decode,
max_len=max_len,
cache=cache,
- index=idx)
+ index=idx,
+ dropout_rate=dropout_rate)
if self.norm is not None:
output = self.norm(output)
@@ -636,7 +630,6 @@ class TransformerDecoderLayer(nn.Module):
nhead: the number of heads in the multiheadattention models (default=16).
dim_feedforward: the dimension of the feedforward network model
(default=1024).
- dropout_rate: the dropout_rate value (default=0.1).
activation: the activation function of the intermediate layer, can be a
string ("relu" or "gelu") or a unary callable (default=F.relu).
layer_norm_eps: the eps value in layer normalization components
@@ -656,8 +649,6 @@ def __init__(self,
d_model: int = 1024,
nhead: int = 16,
dim_feedforward: int = 1024,
- dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
glu: bool = False,
layer_norm_eps: float = 1e-6,
@@ -671,7 +662,6 @@ def __init__(self,
d_model,
nhead,
self_attn=True,
- dropout_rate=attention_dropout_rate,
attention_temp=attention_temp,
bias=False,
**factory_kwargs)
@@ -679,7 +669,6 @@ def __init__(self,
d_model,
nhead,
self_attn=False,
- dropout_rate=attention_dropout_rate,
attention_temp=attention_temp,
bias=False,
**factory_kwargs)
@@ -691,16 +680,12 @@ def __init__(self,
self.linear_glu = nn.Linear(dim_feedforward,
dim_feedforward,
**factory_kwargs)
- self.dropout = nn.Dropout(dropout_rate)
self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
self.pre_ln = pre_ln
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
- self.dropout1 = nn.Dropout(dropout_rate)
- self.dropout2 = nn.Dropout(dropout_rate)
- self.dropout3 = nn.Dropout(dropout_rate)
self.activation = activation
@@ -713,7 +698,8 @@ def forward( # pylint: disable=arguments-renamed
decode: bool = False,
max_len: Optional[int] = None,
cache: Optional[dict] = None,
- index: Optional[int] = None) -> Any:
+ index: Optional[int] = None,
+ dropout_rate: Optional[float] = 0.0) -> Any:
r"""Pass the inputs (and mask) through the decoder layer.
Args:
tgt: the sequence to the decoder layer (required).
@@ -722,6 +708,7 @@ def forward( # pylint: disable=arguments-renamed
memory_mask: the mask for the memory sequence (optional).
decode: whether to use cache for autoregressive decoding or not.
max_len: maximum sequence length, necessary for decoding cache.
+ dropout_rate: the dropout probability value (optional)
Shape:
see the docs in Transformer class.
"""
@@ -735,10 +722,11 @@ def forward( # pylint: disable=arguments-renamed
decode=decode,
max_len=max_len,
cache=cache,
- index=index)
+ index=index,
+ dropout_rate=dropout_rate)
x = x + sa_out
- x = x + self._mha_block(self.norm2(x), memory, memory_mask)
- x = x + self._ff_block(self.norm3(x))
+ x = x + self._mha_block(self.norm2(x), memory, memory_mask, dropout_rate)
+ x = x + self._ff_block(self.norm3(x), dropout_rate)
else:
sa_out, cache = self._sa_block(
x,
@@ -746,10 +734,11 @@ def forward( # pylint: disable=arguments-renamed
decode=decode,
max_len=max_len,
cache=cache,
- index=index)
+ index=index,
+ dropout_rate=dropout_rate)
x = self.norm1(x + sa_out)
- x = self.norm2(x + self._mha_block(x, memory, memory_mask))
- x = self.norm3(x + self._ff_block(x))
+ x = self.norm2(x + self._mha_block(x, memory, memory_mask, dropout_rate))
+ x = self.norm3(x + self._ff_block(x, dropout_rate))
return x, cache
@@ -761,30 +750,38 @@ def _sa_block( # pylint: disable=arguments-renamed
decode: bool = False,
max_len: Optional[int] = None,
cache: Optional[dict] = None,
- index: Optional[int] = None) -> Any:
+ index: Optional[int] = None,
+ dropout_rate: Optional[float] = 0.0) -> Any:
x, cache = self.self_attn(
x,
attn_mask=attn_mask,
decode=decode,
max_len=max_len,
cache=cache,
- index=index)
- return self.dropout1(x), cache
+ index=index,
+ dropout_rate=dropout_rate)
+ return F.dropout(x, dropout_rate, self.training), cache
# Multihead attention block:
def _mha_block(self, x: Tensor, mem: Tensor,
- attn_mask: Optional[Tensor]) -> Tensor:
- x, _ = self.multihead_attn(x, mem, attn_mask=attn_mask)
- return self.dropout2(x)
+ attn_mask: Optional[Tensor],
+ dropout_rate: Optional[float] = 0.0) -> Tensor:
+ x, _ = self.multihead_attn(
+ x,
+ mem,
+ attn_mask=attn_mask,
+ dropout_rate=dropout_rate)
+ return F.dropout(x, dropout_rate, self.training)
# Feed forward block.
- def _ff_block(self, inputs: Tensor) -> Tensor:
+ def _ff_block(self, inputs: Tensor,
+ dropout_rate: Optional[float] = 0.0) -> Tensor:
x = self.activation(self.linear1(inputs))
if self.glu:
y = self.linear_glu(inputs)
x = x * y
- x = self.linear2(self.dropout(x))
- return self.dropout3(x)
+ x = self.linear2(F.dropout(x, dropout_rate, self.training))
+ return F.dropout(x, dropout_rate, self.training)
class MultiheadAttention(nn.Module):
@@ -802,8 +799,6 @@ class MultiheadAttention(nn.Module):
``embed_dim // num_heads``).
self_attn: Whether self attention or encoder-decoder attention is used.
Default: ``True``.
- dropout_rate: Dropout probability on ``attn_output_weights``.
- Default: ``0.0`` (no dropout_rate).
bias: If specified, adds bias to input / output projection layers.
Default: ``False``.
device: The device of the module.
@@ -817,7 +812,6 @@ def __init__(self,
embed_dim: int,
num_heads: int,
self_attn: bool = True,
- dropout_rate: float = 0.,
attention_temp: float = 1.0,
bias: bool = False,
device: Optional[torch.device] = None,
@@ -826,7 +820,6 @@ def __init__(self,
self.embed_dim = embed_dim
self.num_heads = num_heads
self.self_attn = self_attn
- self.dropout = dropout_rate
self.head_dim = embed_dim // num_heads
self.attention_temp = attention_temp
assert self.head_dim * num_heads == self.embed_dim, \
@@ -861,7 +854,8 @@ def forward(self,
decode: bool = False,
max_len: Optional[int] = None,
cache: Optional[dict] = None,
- index: Optional[int] = None) -> Any:
+ index: Optional[int] = None,
+ dropout_rate: Optional[float] = 0.0) -> Any: # TODO: (nico) remove default?!
r"""
Args:
x: Batch of input sequences of shape
@@ -887,6 +881,7 @@ def forward(self,
max_len: maximum sequence length, necessary for decoding cache.
cache: cache dictionary for autoregressive decoding.
index: index of the current decoding step, necessary for decoding cache.
+ dropout_rate: dropout probability on ``attn_output_weights``.
Outputs:
- **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where
:math:`L` is the target sequence length, :math:`N` is the batch size,
@@ -976,12 +971,12 @@ def forward(self,
attn_mask = new_attn_mask
# Adjust dropout probability.
- dropout_rate = self.dropout if self.training else 0.0
+ attn_dropout_rate = dropout_rate if self.training else 0.0
# Calculate attention.
q = self.attention_temp * q
attn_output = torch.nn.functional.scaled_dot_product_attention(
- q, k, v, attn_mask, dropout_rate)
+ q, k, v, attn_mask, attn_dropout_rate)
# Rearrange for output projection.
attn_output = attn_output.transpose(1, 2).contiguous().view(
bsz, tgt_len, embed_dim)
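Editor's note: attention dropout follows the same pattern: the rate arrives as an argument and is zeroed outside training before being handed to scaled_dot_product_attention. A minimal sketch of that gating (tensors illustrative; `training` stands in for module.training):

import torch
import torch.nn.functional as F

q = k = v = torch.randn(2, 4, 8, 16)   # [batch, num_heads, seq_len, head_dim]
training = True
dropout_rate = 0.1

attn_dropout_rate = dropout_rate if training else 0.0
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=attn_dropout_rate)
assert out.shape == q.shape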
diff --git a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py b/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py
deleted file mode 100644
index a43df30d4..000000000
--- a/algoperf/workloads/wmt/wmt_pytorch/models_dropout.py
+++ /dev/null
@@ -1,989 +0,0 @@
-import copy
-import math
-from typing import Any, Callable, Dict, Optional, Tuple, Union
-
-import torch
-from torch import nn
-from torch import Tensor
-import torch.nn.functional as F
-from torch.nn.init import normal_
-from torch.nn.init import xavier_uniform_
-
-DROPOUT_RATE = 0.1
-
-
-def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor:
- """Make a causal mask for self-attention.
-
- Args:
- x: input array of shape `[batch..., len]`
- device: device to store the idxs
-
- Returns:
- A `[batch..., len, len]` shaped causal attention mask.
- """
- idxs = torch.broadcast_to(
- torch.arange(x.shape[-1], dtype=torch.int32, device=device), x.shape)
- return torch.greater_equal(idxs.unsqueeze(-1), idxs.unsqueeze(-2))
-
-
-def make_src_mask(src, inputs_segmentation, nhead):
- """Utility for creating src mask and adjust it for PyTorch Transformer API."""
- src_mask = torch.mul((src > 0).unsqueeze(-1), (src > 0).unsqueeze(-2))
- # Add segmentation block-diagonal attention mask if using segmented data.
- if inputs_segmentation is not None:
- src_mask = torch.logical_and(
- src_mask,
- torch.eq(
- inputs_segmentation.unsqueeze(-1),
- inputs_segmentation.unsqueeze(-2)))
- # Flip values and ensure numerical stability.
- src_mask = torch.repeat_interleave(
- torch.logical_not(src_mask), repeats=nhead, dim=0)
- new_src_mask = torch.zeros_like(src_mask, dtype=torch.float32)
- new_src_mask.masked_fill_(src_mask, -1e10)
- return new_src_mask
-
-
-def make_tgt_and_memory_mask(tgt,
- src,
- inputs_segmentation,
- targets_segmentation,
- decode,
- nhead):
- """ Utility for creating target and memory mask and adjust them for PyTorch
- Transformer API."""
- if not decode:
- tgt_mask = torch.logical_and(
- torch.mul((tgt > 0).unsqueeze(-1), (tgt > 0).unsqueeze(-2)),
- make_causal_mask(tgt, device=tgt.device))
- memory_mask = torch.mul((tgt > 0).unsqueeze(-1), (src > 0).unsqueeze(-2))
- else:
- tgt_mask = None
- memory_mask = torch.mul((torch.ones_like(tgt) > 0).unsqueeze(-1),
- (src > 0).unsqueeze(-2))
- # Add segmentation block-diagonal attention masks if using segmented data.
- if inputs_segmentation is not None:
- tgt_mask = torch.logical_and(
- tgt_mask,
- torch.eq(
- targets_segmentation.unsqueeze(-1),
- targets_segmentation.unsqueeze(-2)))
- memory_mask = torch.logical_and(
- memory_mask,
- torch.eq(
- targets_segmentation.unsqueeze(-1),
- inputs_segmentation.unsqueeze(-2)))
- # Flip values and ensure numerical stability.
- memory_mask = torch.repeat_interleave(
- torch.logical_not(memory_mask), repeats=nhead, dim=0)
- new_memory_mask = torch.zeros_like(memory_mask, dtype=torch.float32)
- new_memory_mask.masked_fill_(memory_mask, -1e10)
- if tgt_mask is not None:
- tgt_mask = torch.repeat_interleave(
- torch.logical_not(tgt_mask), repeats=nhead, dim=0)
- new_tgt_mask = torch.zeros_like(tgt_mask, dtype=torch.float32)
- new_tgt_mask.masked_fill_(tgt_mask, -1e10)
- tgt_mask = new_tgt_mask
- return tgt_mask, new_memory_mask
-
-
-def shift_right(x, axis=1):
- """Shift the input to the right by padding on axis 1."""
- pad_widths = [(0, 0)] * len(x.shape)
- pad_widths[axis] = (1, 0)
- pad_widths = tuple(t for tup in reversed(pad_widths) for t in tup)
- padded = F.pad(x, pad_widths, mode='constant')
- return padded[:, :-1]
-
-
-class Transformer(nn.Module):
- """Transformer architecture based on the model from the WMT Jax workload."""
-
- def __init__(self,
- ntoken: int = 32000,
- d_model: int = 1024,
- nhead: int = 16,
- d_hid: int = 1024,
- nlayers: int = 6,
- activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
- glu: bool = False,
- layer_norm_eps: float = 1e-6,
- attention_temp: float = 1.0,
- pre_ln: bool = True):
- super().__init__()
- self.pos_encoder = PositionalEncoding(d_model)
- self.shared_embedding = nn.Embedding(ntoken, d_model)
- self.encoder = Encoder(d_model,
- nhead,
- d_hid,
- nlayers,
- activation,
- glu,
- layer_norm_eps,
- attention_temp,
- pre_ln)
- self.decoder = Decoder(d_model,
- nhead,
- d_hid,
- nlayers,
- activation,
- glu,
- layer_norm_eps,
- attention_temp,
- pre_ln)
- # Share positional encoding and embedding between encoder and decoder.
- self.encoder.pos_encoder = self.pos_encoder
- self.encoder.shared_embedding = self.shared_embedding
- self.decoder.pos_encoder = self.pos_encoder
- self.decoder.shared_embedding = self.shared_embedding
-
- self._reset_parameters()
-
- def _reset_parameters(self):
- """Initiate parameters in the transformer model."""
- for module in self.modules():
- if isinstance(module, nn.Linear):
- xavier_uniform_(module.weight)
- if module.bias is not None:
- normal_(module.bias, std=1e-6)
-
- def forward(self,
- src: Tensor,
- tgt: Tensor,
- inputs_positions: Optional[Tensor] = None,
- targets_positions: Optional[Tensor] = None,
- inputs_segmentation: Optional[Tensor] = None,
- targets_segmentation: Optional[Tensor] = None,
- decode: bool = False,
- dropout_rate: float = DROPOUT_RATE) -> Tensor:
- """
- Args:
- src: Tensor, shape [batch_size, seq_len]
- tgt: Tensor, shape [batch_size, seq_len]
- inputs_positions: Optional[Tensor], shape [batch_size, seq_len]
- targets_positions: Optional[Tensor], shape [batch_size, seq_len]
- inputs_segmentation: Optional[Tensor], shape [batch_size, seq_len]
- targets_segmentation: Optional[Tensor], shape [batch_size, seq_len]
- decode: bool
- dropout_rate: float
-
- Returns:
- output Tensor of shape [batch_size, seq_len, ntoken]
- """
- if src.size(0) != tgt.size(0):
- raise RuntimeError('The batch size of src and tgt must be equal.')
-
- memory = self.encoder(
- src,
- inputs_positions=inputs_positions,
- inputs_segmentation=inputs_segmentation,
- dropout_rate=dropout_rate)
- output = self.decoder(
- tgt,
- memory,
- src, # just for calculating the padding mask
- targets_positions=targets_positions,
- inputs_segmentation=inputs_segmentation,
- targets_segmentation=targets_segmentation,
- decode=decode,
- dropout_rate=dropout_rate)
- return output
-
-
-class TransformerEncoder(nn.Module):
- r"""TransformerEncoder is a stack of N encoder layers. Users can build the
- BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters.
-
- Args:
- encoder_layer: an instance of the TransformerEncoderLayer() class.
- num_layers: the number of sub-encoder-layers in the encoder.
- norm: the layer normalization component (optional).
- enable_nested_tensor: if True, input will automatically convert to
- nested tensor (and convert back on output). This will improve
- the overall performance of TransformerEncoder when padding
- rate is high.
-
- Examples::
- >>> encoder_layer = nn.TransformerEncoderLayer(12, 8)
- >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, 6)
- >>> src = torch.rand(10, 32, 512)
- >>> out = transformer_encoder(src)
- """
- __constants__ = ['norm']
-
- def __init__(self,
- encoder_layer,
- num_layers,
- norm=None,
- enable_nested_tensor=True,
- mask_check=True):
- super().__init__()
- self.layers = nn.ModuleList(
- [copy.deepcopy(encoder_layer) for _ in range(num_layers)])
- self.num_layers = num_layers
- self.norm = norm
- self.enable_nested_tensor = enable_nested_tensor
- self.mask_check = mask_check
-
- def forward(self, src: Tensor,
- mask: Optional[Tensor] = None,
- dropout_rate: Optional[float] = 0.0) -> Tensor:
- """Pass the input through the encoder layers in turn.
-
- Args:
- src: the sequence to the encoder (required).
- mask: the mask for the src sequence (optional).
- dropout_rate: the dropout probability (optional).
-
- Shape:
- see the docs in Transformer class.
- """
- output = src
- convert_to_nested = False
-
- for mod in self.layers:
- output = mod(output, src_mask=mask, dropout_rate=dropout_rate)
-
- if convert_to_nested:
- output = output.to_padded_tensor(0.)
-
- if self.norm is not None:
- output = self.norm(output)
-
- return output
-
-
-class Encoder(nn.Module):
-
- def __init__(self,
- d_model: int = 1024,
- nhead: int = 16,
- d_hid: int = 1024,
- nlayers: int = 6,
- activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
- glu: bool = False,
- layer_norm_eps: float = 1e-6,
- attention_temp: float = 1.0,
- pre_ln: bool = True):
- super().__init__()
- self.nhead = nhead
- self.shared_embedding = None
- self.pos_encoder = None
- encoder_layer = TransformerEncoderLayer(
- d_model,
- nhead,
- d_hid,
- activation=activation,
- glu=glu,
- layer_norm_eps=layer_norm_eps,
- attention_temp=attention_temp,
- pre_ln=pre_ln)
- encoder_norm = (
- nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None)
- self.encoder = TransformerEncoder(encoder_layer, nlayers, encoder_norm)
-
- def forward(self,
- src: Tensor,
- inputs_positions: Optional[Tensor] = None,
- inputs_segmentation: Optional[Tensor] = None,
- dropout_rate: Optional[float] = 0.0) -> Tensor:
- src = src.to(torch.int)
- src_mask = make_src_mask(src, inputs_segmentation, self.nhead)
- src = self.shared_embedding(src)
- src = self.pos_encoder(src, inputs_positions, dropout_rate=dropout_rate)
- memory = self.encoder(src, mask=src_mask, dropout_rate=dropout_rate)
- return memory
-
-
-class Decoder(nn.Module):
-
- def __init__(self,
- d_model: int = 1024,
- nhead: int = 16,
- d_hid: int = 1024,
- nlayers: int = 6,
- activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
- glu: bool = False,
- layer_norm_eps: float = 1e-6,
- attention_temp: float = 1.0,
- pre_ln: bool = True):
- super().__init__()
- self.nhead = nhead
- self.shared_embedding = None
- self.pos_encoder = None
- self.decoder = TransformerDecoder(d_model,
- nhead,
- d_hid,
- activation,
- glu,
- layer_norm_eps,
- nlayers,
- attention_temp,
- pre_ln)
-
- def forward(
- self,
- tgt: Tensor,
- memory: Tensor,
- src: Tensor, # just for calculating the padding mask
- targets_positions: Optional[Tensor] = None,
- inputs_segmentation: Optional[Tensor] = None,
- targets_segmentation: Optional[Tensor] = None,
- decode: bool = False,
- max_len: Optional[int] = None,
- cache: Optional[dict] = None,
- dropout_rate: Optional[float] = 0.0) -> Any:
- tgt = tgt.to(torch.int)
- tgt_mask, memory_mask = make_tgt_and_memory_mask(
- tgt, src, inputs_segmentation, targets_segmentation,
- decode, self.nhead)
- if not decode:
- tgt = shift_right(tgt)
- tgt = self.shared_embedding(tgt)
- tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache, dropout_rate=dropout_rate)
- if decode:
- tgt, cache = tgt
- output = self.decoder(
- tgt,
- memory,
- tgt_mask=tgt_mask,
- memory_mask=memory_mask,
- decode=decode,
- max_len=max_len,
- cache=cache,
- dropout_rate=dropout_rate)
- if decode:
- output, cache = output
- normalize = math.sqrt(output.shape[-1])
- output = torch.matmul(output, self.shared_embedding.weight.T) / normalize
- if decode:
- return output, cache
- return output
-
-
-class PositionalEncoding(nn.Module):
-
- def __init__(self,
- d_model: int,
- max_len: int = 256):
- super().__init__()
-
- position = torch.arange(max_len).unsqueeze(1)
- scale_factor = -math.log(10000.0) / (d_model // 2 - 1)
- div_term = torch.exp(torch.arange(d_model // 2) * scale_factor)
- pe = torch.zeros(1, max_len, d_model)
- pe[0, :, :d_model // 2] = torch.sin(position * div_term)
- pe[0, :, d_model // 2:2 * (d_model // 2)] = torch.cos(position * div_term)
- self.register_buffer('pe', pe)
-
- def forward(
- self,
- x: Tensor,
- inputs_positions: Optional[Tensor] = None,
- decode: bool = False,
- cache: Optional[Dict[str, Dict[str, Tensor]]] = None,
- dropout_rate: Optional[float] = 0.0
- ) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]:
- """
- Args:
- x: Tensor (shape [batch_size, seq_len, embedding_dim])
- inputs_positions: Tensor (shape [batch_size, seq_len]) or None
- decode: bool
- cache: Dict[str, Dict[str, Tensor]] or None
- dropout_rate: Optional[float]
- Returns:
- Tensor or Tuple[Tensor, Dict[str, Dict[str, Tensor]]]
- """
- # We use a cache position index for tracking decoding position.
- if decode:
- name = self._get_name()
- if cache is None:
- cache = {
- name: {
- 'cache_index':
- torch.tensor(0, dtype=torch.long, device=self.pe.device),
- },
- }
- pe = self.pe[0, cache[name]['cache_index'], :]
- cache[name]['cache_index'] += 1
- return F.dropout(x + pe, dropout_rate, self.training), cache
- if inputs_positions is None:
- # normal unpacked case:
- pe = self.pe[:, :x.size(1), :]
- else:
- # for packed data we need to use known position indices:
- pe = self.pe[0, inputs_positions, :]
- return F.dropout(x + pe, dropout_rate, self.training)
-
-
-# TransformerEncoderLayer and TransformerDecoderLayer are taken from:
-# https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/transformer.py
-# Main difference is the use of custom MultiheadAttention modules.
-class TransformerEncoderLayer(nn.Module):
- r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
- This standard encoder layer is based on the paper "Attention Is All You Need".
- Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
- Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all
- you need. In Advances in Neural Information Processing Systems,
- pages 6000-6010. Users may modify or implement in a different way during
- application.
- Args:
- d_model: the number of expected features in the input (default=1024).
- nhead: the number of heads in the multiheadattention models (default=16).
- dim_feedforward: the dimension of the feedforward network model
- (default=1024).
- activation: the activation function of the intermediate layer, can be a
- string ("relu" or "gelu") or a unary callable (default=F.relu).
- layer_norm_eps: the eps value in layer normalization components
- (default=1e-6).
- pre_ln: if ``True``, layer norm is done prior to attention and
- feedforward operations, respectively. Otherwise it's done after.
- Default: ``True``.
- Examples::
- >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
- >>> src = torch.rand(32, 10, 512)
- >>> out = encoder_layer(src)
- """
- __constants__ = ['pre_ln']
-
- def __init__(self,
- d_model: int = 1024,
- nhead: int = 16,
- dim_feedforward: int = 1024,
- activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
- glu: bool = False,
- layer_norm_eps: float = 1e-6,
- attention_temp: float = 1.0,
- pre_ln: bool = True,
- device=None,
- dtype=None) -> None:
- factory_kwargs = {'device': device, 'dtype': dtype}
- super().__init__()
- self.self_attn = MultiheadAttention(
- d_model,
- nhead,
- self_attn=True,
- attention_temp=attention_temp,
- bias=False,
- **factory_kwargs)
-
- # Implementation of Feedforward model.
- self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
- self.glu = glu
- if self.glu:
- self.linear_glu = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
- self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
-
- self.pre_ln = pre_ln
- self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
- self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
-
- self.activation = activation
-
- def forward(self,
- src: Tensor,
- src_mask: Optional[Tensor] = None,
- dropout_rate: Optional[float] = 0.0) -> Tensor:
- r"""Pass the input through the encoder layer.
-
- Args:
- src: the sequence to the encoder layer (required).
- src_mask: the mask for the src sequence (optional).
- dropout_rate: the dropout probability value (optional).
- Shape:
- see the docs in Transformer class.
- """
- x = src
- if self.pre_ln:
- x = x + self._sa_block(self.norm1(x), src_mask, dropout_rate)
- x = x + self._ff_block(self.norm2(x), dropout_rate)
- else:
- x = self.norm1(x + self._sa_block(x, src_mask, dropout_rate))
- x = self.norm2(x + self._ff_block(x, dropout_rate))
-
- return x
-
- # Self-attention block:
- def _sa_block(self,
- x: Tensor,
- attn_mask: Optional[Tensor],
- dropout_rate: Optional[float] = 0.0) -> Tensor:
- x, _ = self.self_attn(x, attn_mask=attn_mask, dropout_rate=dropout_rate)
- return F.dropout(x, dropout_rate, training=self.training)
-
- # Feed forward block:
- def _ff_block(self,
- inputs: Tensor,
- dropout_rate: Optional[float] = 0.0) -> Tensor:
- x = self.activation(self.linear1(inputs))
- if self.glu:
- y = self.linear_glu(inputs)
- x = x * y
- x = self.linear2(F.dropout(x, dropout_rate, training=self.training))
- return F.dropout(x, dropout_rate, training=self.training)
-
-
-# Modified to use cache for autoregressive decoding and custom
-# MultiheadAttention modules.
-class TransformerDecoder(nn.Module):
- r"""TransformerDecoder is a stack of N decoder layers
- Args:
- d_model: the number of expected features in the input (default=1024)
- nhead: the number of heads in the multiheadattention models (default=16)
- d_hid: the dimension of the feedforward network model
- (default=1024)
- layer_norm_eps: the eps value in layer normalization components
- (default=1e-6).
- decoder_layer: an instance of the TransformerDecoderLayer() class
- num_layers: the number of sub-decoder-layers in the decoder
- Examples::
- >>> decoder_layer = nn.TransformerDecoderLayer(12, 8)
- >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, 6)
- >>> memory = torch.rand(10, 32, 512)
- >>> tgt = torch.rand(20, 32, 512)
- >>> out = transformer_decoder(tgt, memory)
- """
- __constants__ = ['norm']
-
- def __init__(self,
- d_model,
- nhead,
- d_hid,
- activation,
- glu,
- layer_norm_eps,
- num_layers,
- attention_temp,
- pre_ln):
- super().__init__()
- self.layers = nn.ModuleList([
- TransformerDecoderLayer(
- d_model,
- nhead,
- d_hid,
- activation,
- glu,
- layer_norm_eps=layer_norm_eps,
- attention_temp=attention_temp,
- pre_ln=pre_ln) for _ in range(num_layers)
- ])
- self.num_layers = num_layers
- self.norm = (nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None)
-
- def forward(self,
- tgt: Tensor,
- memory: Tensor,
- tgt_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- decode: bool = False,
- max_len: Optional[int] = None,
- cache: Optional[dict] = None,
- dropout_rate: Optional[float] = 0.0) -> Any:
- r"""Pass the inputs (and mask) through the decoder layer in turn.
- Args:
- tgt: the sequence to the decoder (required).
- memory: the sequence from the last layer of the encoder (required).
- tgt_mask: the mask for the tgt sequence (optional).
- memory_mask: the mask for the memory sequence (optional).
- decode: whether to use cache for autoregressive decoding or not.
- max_len: maximum sequence length, necessary for decoding cache.
- dropout_rate: the dropout probability value (optional)
- Shape:
- see the docs in Transformer class.
- """
- output = tgt
-
- for idx, mod in enumerate(self.layers):
- output, cache = mod(
- output,
- memory,
- tgt_mask=tgt_mask,
- memory_mask=memory_mask,
- decode=decode,
- max_len=max_len,
- cache=cache,
- index=idx,
- dropout_rate=dropout_rate)
-
- if self.norm is not None:
- output = self.norm(output)
-
- if decode:
- return output, cache
- return output
-
-
-# Modified to use cache for autoregressive decoding and custom
-# MultiheadAttention modules.
-class TransformerDecoderLayer(nn.Module):
- r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and
- feedforward network.
- This standard decoder layer is based on the paper "Attention Is All You Need".
- Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
- Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all
- you need. In Advances in Neural Information Processing Systems,
- pages 6000-6010. Users may modify or implement in a different way during
- application.
- Args:
- d_model: the number of expected features in the input (default=1024).
- nhead: the number of heads in the multiheadattention models (default=16).
- dim_feedforward: the dimension of the feedforward network model
- (default=1024).
- activation: the activation function of the intermediate layer, can be a
- string ("relu" or "gelu") or a unary callable (default=F.relu).
- layer_norm_eps: the eps value in layer normalization components
- (default=1e-6).
- pre_ln: if ``True``, layer norm is done prior to self attention,
- multihead attention and feedforward operations, respectively.
- Otherwise it's done after. Default: ``True``.
- Examples::
- >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
- >>> memory = torch.rand(32, 10, 512)
- >>> tgt = torch.rand(32, 20, 512)
- >>> out = decoder_layer(tgt, memory)
- """
- __constants__ = ['pre_ln']
-
- def __init__(self,
- d_model: int = 1024,
- nhead: int = 16,
- dim_feedforward: int = 1024,
- activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
- glu: bool = False,
- layer_norm_eps: float = 1e-6,
- pre_ln: bool = True,
- attention_temp: float = 1.0,
- device=None,
- dtype=None) -> None:
- factory_kwargs = {'device': device, 'dtype': dtype}
- super().__init__()
- self.self_attn = MultiheadAttention(
- d_model,
- nhead,
- self_attn=True,
- attention_temp=attention_temp,
- bias=False,
- **factory_kwargs)
- self.multihead_attn = MultiheadAttention(
- d_model,
- nhead,
- self_attn=False,
- attention_temp=attention_temp,
- bias=False,
- **factory_kwargs)
-
- # Implementation of Feedforward model.
- self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
- self.glu = glu
- if self.glu:
- self.linear_glu = nn.Linear(dim_feedforward,
- dim_feedforward,
- **factory_kwargs)
- self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
-
- self.pre_ln = pre_ln
- self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
- self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
- self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
-
- self.activation = activation
-
- def forward( # pylint: disable=arguments-renamed
- self,
- tgt: Tensor,
- memory: Tensor,
- tgt_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- decode: bool = False,
- max_len: Optional[int] = None,
- cache: Optional[dict] = None,
- index: Optional[int] = None,
- dropout_rate: Optional[float] = 0.0) -> Any:
- r"""Pass the inputs (and mask) through the decoder layer.
- Args:
- tgt: the sequence to the decoder layer (required).
- memory: the sequence from the last layer of the encoder (required).
- tgt_mask: the mask for the tgt sequence (optional).
- memory_mask: the mask for the memory sequence (optional).
- decode: whether to use cache for autoregressive decoding or not.
- max_len: maximum sequence length, necessary for decoding cache.
- dropout_rate: the dropout probability value (optional)
- Shape:
- see the docs in Transformer class.
- """
- # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
-
- x = tgt
- if self.pre_ln:
- sa_out, cache = self._sa_block(
- self.norm1(x),
- tgt_mask,
- decode=decode,
- max_len=max_len,
- cache=cache,
- index=index,
- dropout_rate=dropout_rate)
- x = x + sa_out
- x = x + self._mha_block(self.norm2(x), memory, memory_mask, dropout_rate)
- x = x + self._ff_block(self.norm3(x), dropout_rate)
- else:
- sa_out, cache = self._sa_block(
- x,
- tgt_mask,
- decode=decode,
- max_len=max_len,
- cache=cache,
- index=index,
- dropout_rate=dropout_rate)
- x = self.norm1(x + sa_out)
- x = self.norm2(x + self._mha_block(x, memory, memory_mask, dropout_rate))
- x = self.norm3(x + self._ff_block(x, dropout_rate))
-
- return x, cache
-
- # Self-attention block:
- def _sa_block( # pylint: disable=arguments-renamed
- self,
- x: Tensor,
- attn_mask: Optional[Tensor],
- decode: bool = False,
- max_len: Optional[int] = None,
- cache: Optional[dict] = None,
- index: Optional[int] = None,
- dropout_rate: Optional[float] = 0.0) -> Any:
- x, cache = self.self_attn(
- x,
- attn_mask=attn_mask,
- decode=decode,
- max_len=max_len,
- cache=cache,
- index=index,
- dropout_rate=dropout_rate)
- return F.dropout(x, dropout_rate, self.training), cache
-
- # Multihead attention block:
- def _mha_block(self, x: Tensor, mem: Tensor,
- attn_mask: Optional[Tensor],
- dropout_rate: Optional[float] = 0.0) -> Tensor:
- x, _ = self.multihead_attn(
- x,
- mem,
- attn_mask=attn_mask,
- dropout_rate=dropout_rate)
- return F.dropout(x, dropout_rate, self.training)
-
- # Feed forward block.
- def _ff_block(self, inputs: Tensor,
- dropout_rate: Optional[float] = 0.0) -> Tensor:
- x = self.activation(self.linear1(inputs))
- if self.glu:
- y = self.linear_glu(inputs)
- x = x * y
- x = self.linear2(F.dropout(x, dropout_rate, self.training))
- return F.dropout(x, dropout_rate, self.training)
-
-
-class MultiheadAttention(nn.Module):
- r"""Allows the model to jointly attend to information
- from different representation subspaces. Supports self-attention and
- encoder-decoder attention.
- See `Attention Is All You Need `_.
- .. math::
- \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
- where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
- Args:
- embed_dim: Total dimension of the model.
- num_heads: Number of parallel attention heads. Note that ``embed_dim`` will
- be split across ``num_heads`` (i.e. each head will have dimension
- ``embed_dim // num_heads``).
- self_attn: Whether self attention or encoder-decoder attention is used.
- Default: ``True``.
- bias: If specified, adds bias to input / output projection layers.
- Default: ``False``.
- device: The device of the module.
- dtype: The dtype of the module.
- Examples::
- >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
- >>> attn_output, cache = multihead_attn(x)
- """
-
- def __init__(self,
- embed_dim: int,
- num_heads: int,
- self_attn: bool = True,
- attention_temp: float = 1.0,
- bias: bool = False,
- device: Optional[torch.device] = None,
- dtype: Optional[torch.dtype] = None) -> None:
- super().__init__()
- self.embed_dim = embed_dim
- self.num_heads = num_heads
- self.self_attn = self_attn
- self.head_dim = embed_dim // num_heads
- self.attention_temp = attention_temp
- assert self.head_dim * num_heads == self.embed_dim, \
- 'embed_dim must be divisible by num_heads.'
-
- factory_kwargs = {'device': device, 'dtype': dtype}
- if self_attn:
- # Self-attention.
- self.in_proj = nn.Linear(
- embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
- else:
- # Encoder-decoder attention.
- self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
- self.kv_proj = nn.Linear(
- embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
- self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
-
- self._reset_parameters()
-
- def _reset_parameters(self):
- """Initiate parameters in the MultiheadAttention module."""
- for module in self.modules():
- if isinstance(module, nn.Linear):
- xavier_uniform_(module.weight)
- if module.bias is not None:
- normal_(module.bias, std=1e-6)
-
- def forward(self,
- x: Tensor,
- mem: Optional[Tensor] = None,
- attn_mask: Optional[Tensor] = None,
- decode: bool = False,
- max_len: Optional[int] = None,
- cache: Optional[dict] = None,
- index: Optional[int] = None,
- dropout_rate: Optional[float] = 0.0) -> Any: # TODO: (nico) remove default?!
- r"""
- Args:
- x: Batch of input sequences of shape
- (batch size, sequence length, embedding dimensionality) for self
- attention mechanism. See "Attention Is All You Need" for more details.
- mem: Batch of input sequences of shape
- (batch size, sequence length, embedding dimensionality) for
- encoder-decoder attention. See "Attention Is All You Need" for more
- details.
- attn_mask: If specified, a 2D or 3D mask preventing attention to certain
- positions. Must be of shape :math:`(L, S)` or
- :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the
- batch size, :math:`L` is the target sequence length, and :math:`S`
- is the source sequence length. A 2D mask will be broadcasted across
- the batch while a 3D mask allows for a different mask for each entry
- in the batch. Binary, byte, and float masks are supported.
- For a binary mask, a ``True`` value indicates that the
- corresponding position is not allowed to attend. For a byte mask,
- a non-zero value indicates that the corresponding position is not
- allowed to attend. For a float mask, the mask values will be added to
- the attention weight.
- decode: whether to use cache for autoregressive decoding or not.
- max_len: maximum sequence length, necessary for decoding cache.
- cache: cache dictionary for autoregressive decoding.
- index: index of the current decoding step, necessary for decoding cache.
- dropout_rate: dropout probability on ``attn_output_weights``.
- Outputs:
- - **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where
- :math:`L` is the target sequence length, :math:`N` is the batch size,
- and :math:`E` is the embedding dimension ``embed_dim``.
- - **cache** - For autoregressive decoding.
- """
- # Shape: (batch size, sequence length, embedding dimensionality)
- bsz, seq_len, embed_dim = x.size()
- # In projection.
- if self.self_attn:
- q, k, v = self.in_proj(x).split(self.embed_dim, dim=2)
- else:
- q = self.q_proj(x)
- k, v = self.kv_proj(mem).split(self.embed_dim, dim=2)
- # This is 1 (!= seq_len) during autoregressive decoding.
- tgt_len = q.size(1)
-
- # During fast autoregressive decoding, we feed one position at a time,
- # and cache the keys and values step by step.
- name = f'decoder.layers.{index}.self_attn'
- loc_cache = cache[name] if decode and name in cache else None
- if decode:
- if loc_cache is None:
- loc_cache = {
- 'cached_key':
- torch.zeros((bsz, max_len, embed_dim),
- dtype=k.dtype,
- device=k.device),
- 'cached_value':
- torch.zeros((bsz, max_len, embed_dim),
- dtype=v.dtype,
- device=v.device),
- 'cache_index':
- torch.tensor(0, dtype=torch.long, device=k.device),
- }
- cached_key = loc_cache['cached_key']
- cached_value = loc_cache['cached_value']
- cache_index = loc_cache['cache_index']
- # Shape check of cached keys against query input.
- expected_shape = (bsz, 1, embed_dim)
- if expected_shape != x.shape:
- raise ValueError('Autoregressive cache shape error, expected query '
- f'shape {expected_shape} instead got {x.shape}.')
- # Update key, value caches with our new 1d spatial slices.
- cached_key[:, cache_index:cache_index + 1, :] = k
- cached_value[:, cache_index:cache_index + 1, :] = v
- k = cached_key
- v = cached_value
- cache_index += 1
- # Causal mask for cached decoder self-attention:
- # our single query position should only attend to those key
- # positions that have already been generated and cached,
- # not the remaining zero elements.
- if attn_mask is not None:
- raise ValueError('Attention mask has to be None for decode == True.')
- attn_mask = (torch.arange(max_len, device=k.device) >=
- cache_index).reshape(1, max_len)
-
- # Update sequence length to account for complete sequence.
- seq_len = k.size(1)
-
- # Rearrange q, k, v for multihead attention.
- q = q.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
- k = k.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
- v = v.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
-
- # Check dtype and shape of attention mask.
- if not decode and attn_mask is not None:
- assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
- f'Float and bool dtypes are supported, not {attn_mask.dtype}.'
- # Ensure attn_mask's dim is 3.
- if attn_mask.dim() == 3:
- correct_3d_size = (bsz * self.num_heads, tgt_len, seq_len)
- if attn_mask.shape != correct_3d_size:
- raise RuntimeError(f'The shape of attn_mask is {attn_mask.shape}, '
- f'but should be {correct_3d_size}.')
- else:
- raise RuntimeError(
- f"attn_mask's dimension {attn_mask.dim()} is not supported")
- # Reshape attention mask to be consistent with q, k, v.
- attn_mask = attn_mask.view(bsz, self.num_heads, tgt_len, seq_len)
-
- # Convert attention mask to float.
- if attn_mask is not None and attn_mask.dtype == torch.bool:
- new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
- new_attn_mask.masked_fill_(attn_mask, -1e10)
- attn_mask = new_attn_mask
-
- # Adjust dropout_rate probability.
- attn_dropout_rate = dropout_rate if self.training else 0.0
-
- # Calculate attention.
- q = self.attention_temp * q
- attn_output = torch.nn.functional.scaled_dot_product_attention(
- q, k, v, attn_mask, attn_dropout_rate)
- # Rearrange for output projection.
- attn_output = attn_output.transpose(1, 2).contiguous().view(
- bsz, tgt_len, embed_dim)
- # Output projection.
- attn_output = self.out_proj(attn_output)
-
- if decode:
- cache[name] = loc_cache
-
- return attn_output, cache
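For reference, a minimal standalone sketch of the decode-time masking rule used in the removed forward pass above (the helper name is illustrative, not part of the workload API): after the new key/value slice is written and the cache index advanced, every cached position at or beyond the index is blocked.

import torch


def decode_step_mask(max_len: int, cache_index: int) -> torch.Tensor:
  # cache_index is the value *after* the in-place increment above, so the
  # single query may attend to positions [0, cache_index) only. True marks
  # blocked positions, matching the convention in the code above (True
  # entries are later filled with -1e10).
  return (torch.arange(max_len) >= cache_index).reshape(1, max_len)


print(decode_step_mask(max_len=6, cache_index=3))
# tensor([[False, False, False,  True,  True,  True]])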
From f7d99a62670e8c525eef295f423b47f2026f5a38 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Tue, 17 Jun 2025 21:17:19 +0000
Subject: [PATCH 060/123] fixes
---
algoperf/jax_utils.py | 8 ++++----
.../librispeech_conformer/librispeech_jax/models.py | 9 ++-------
2 files changed, 6 insertions(+), 11 deletions(-)
diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py
index 3ca3f1bfc..c4904dc75 100644
--- a/algoperf/jax_utils.py
+++ b/algoperf/jax_utils.py
@@ -14,8 +14,8 @@ class Dropout(Module):
"""Create a dropout layer.
Forked from https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout.
   The reference dropout implementation is modified to support changes to dropout rate during training by:
- 1) adding rate argument to the __call__ method
- 2) removing the if-else condition to check for edge cases, which will trigger a recompile for jitted code
+ 1) adding rate argument to the __call__ method.
+ 2) removing the if-else condition to check for edge cases, which will trigger a recompile for jitted code.
.. note::
When using :meth:`Module.apply() `, make sure
@@ -82,8 +82,8 @@ def __call__(
deterministic = merge_param("deterministic", self.deterministic, deterministic)
# Override self.rate if rate is passed to __call__
- if not (self.rate is not None and rate is not None):
- rate = merge_param("rate", self.rate, rate)
+ if rate is None:
+ rate = self.rate
if self.legacy:
if rate == 0.0:
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
index f7beed914..0de6b1449 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
@@ -221,12 +221,7 @@ def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_
inputs)
inputs = inputs * padding_mask
- if config.feed_forward_residual_dropout_rate is None:
- feed_forward_residual_dropout_rate = 0.1
- else:
- feed_forward_residual_dropout_rate = (
- config.feed_forward_residual_dropout_rate)
- inputs = Dropout(rate=feed_forward_residual_dropout_rate)(
+ inputs = Dropout(rate=dropout_rate)(
inputs, deterministic=not train)
return inputs
@@ -401,7 +396,7 @@ def __call__(self, inputs, paddings, train, dropout_rate=DROPOUT_RATE):
use_bias=True,
broadcast_dropout=False,
attention_fn=attention_fn,
- dropout_rate=config.attention_dropout_rate,
+ dropout_rate=dropout_rate,
deterministic=not train)(
inputs_q=inputs, mask=attention_mask)
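After this change the call-time rate simply wins over the constructor value whenever it is not None. A minimal usage sketch of the forked layer under that rule (TinyBlock and its shapes are illustrative):

import flax.linen as nn
import jax
import jax.numpy as jnp

from algoperf.jax_utils import Dropout


class TinyBlock(nn.Module):

  @nn.compact
  def __call__(self, x, train: bool, dropout_rate: float = 0.1):
    x = nn.Dense(4)(x)
    # The rate passed at call time overrides the constructor value (0.0).
    return Dropout(rate=0.0)(x, deterministic=not train, rate=dropout_rate)


x = jnp.ones((1, 3))
variables = TinyBlock().init(jax.random.PRNGKey(0), x, train=False)
out = TinyBlock().apply(
    variables, x, train=True, dropout_rate=0.2,
    rngs={'dropout': jax.random.PRNGKey(1)})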
From 3a41559dea66881a6264d2e1c96ff4d58a530353 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 18 Jun 2025 18:56:24 +0000
Subject: [PATCH 061/123] fix reference_algorithm_tests.py
---
tests/reference_algorithm_tests.py | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py
index c4ca514a8..58a4a5ddc 100644
--- a/tests/reference_algorithm_tests.py
+++ b/tests/reference_algorithm_tests.py
@@ -184,12 +184,11 @@ def __init__(self):
if 'librispeech' in workload_name:
self.tokenizer = _FakeTokenizer()
- def init_model_fn(self, rng, dropout_rate=None, aux_dropout_rate=None):
+ def init_model_fn(self, rng):
# pylint: disable=line-too-long
if not (FLAGS.identical and
os.path.exists(f'tests/modeldiffs/{workload_name}/compare.py')):
- return super().init_model_fn(
- rng, dropout_rate=dropout_rate, aux_dropout_rate=aux_dropout_rate)
+ return super().init_model_fn(rng)
if framework == 'jax':
compare_module = importlib.import_module(
f'tests.modeldiffs.{workload_name}.compare')
@@ -201,7 +200,7 @@ def init_model_fn(self, rng, dropout_rate=None, aux_dropout_rate=None):
return (FrozenDict(**jax_utils.replicate(jax_params)),
FrozenDict(**jax_utils.replicate(model_state))
if model_state is not None else model_state)
- return super().init_model_fn([0], dropout_rate=0.0, aux_dropout_rate=0.0)
+ return super().init_model_fn([0])
@property
def num_eval_train_examples(self):
From 7c430227152b8951fb21454960f71169ab00eb09 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 18 Jun 2025 20:12:45 +0000
Subject: [PATCH 062/123] fixes to ogbg and fastmri
---
algoperf/workloads/fastmri/fastmri_jax/workload.py | 3 ++-
algoperf/workloads/ogbg/ogbg_jax/models.py | 11 +++++------
2 files changed, 7 insertions(+), 7 deletions(-)
diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py
index acdf077e1..b8067cbad 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/workload.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py
@@ -11,6 +11,7 @@
from algoperf import param_utils
from algoperf import spec
import algoperf.random_utils as prng
+from algoperf.workloads.fastmri.fastmri_jax.models import DROPOUT_RATE
from algoperf.workloads.fastmri.fastmri_jax.models import UNet
from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim
from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload
@@ -52,7 +53,7 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
train = mode == spec.ForwardPassMode.TRAIN
diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py
index 59d989284..d51ca2f20 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/models.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/models.py
@@ -18,7 +18,7 @@ def make_fn(inputs):
return make_fn
-def _make_mlp(hidden_dims, dropout, activation_fn):
+def _make_mlp(hidden_dims, activation_fn, train, dropout_rate=DROPOUT_RATE):
"""Creates a MLP with specified dimensions."""
@jraph.concatenated_args
@@ -28,7 +28,7 @@ def make_fn(inputs):
x = nn.Dense(features=dim)(x)
x = nn.LayerNorm()(x)
x = activation_fn(x)
- x = dropout(x)
+ x = Dropout(rate=dropout_rate, deterministic=not train)(x, rate=dropout_rate)
return x
return make_fn
@@ -47,7 +47,6 @@ class GNN(nn.Module):
@nn.compact
def __call__(self, graph, train, dropout_rate=DROPOUT_RATE):
- dropout = Dropout(dropout_rate, deterministic=not train)(dropout_rate)
graph = graph._replace(
globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]))
@@ -70,11 +69,11 @@ def __call__(self, graph, train, dropout_rate=DROPOUT_RATE):
for _ in range(self.num_message_passing_steps):
net = jraph.GraphNetwork(
update_edge_fn=_make_mlp(
- self.hidden_dims, dropout=dropout, activation_fn=activation_fn),
+ self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate),
update_node_fn=_make_mlp(
- self.hidden_dims, dropout=dropout, activation_fn=activation_fn),
+ self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate),
update_global_fn=_make_mlp(
- self.hidden_dims, dropout=dropout, activation_fn=activation_fn))
+ self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate))
graph = net(graph)
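The refactored helper keeps jraph's update functions single-argument by closing over train and dropout_rate instead of a pre-built Dropout instance. A self-contained sketch of the same pattern (make_mlp and TinyGNNLayer are illustrative names; the jraph.concatenated_args decorator is omitted):

import flax.linen as nn
import jax
import jax.numpy as jnp

from algoperf.jax_utils import Dropout


def make_mlp(hidden_dims, activation_fn, train, dropout_rate):
  # The closure carries train/dropout_rate, so the returned function still
  # takes a single array argument.
  def apply_mlp(x):
    for dim in hidden_dims:
      x = nn.Dense(features=dim)(x)
      x = nn.LayerNorm()(x)
      x = activation_fn(x)
      x = Dropout(rate=dropout_rate, deterministic=not train)(
          x, rate=dropout_rate)
    return x
  return apply_mlp


class TinyGNNLayer(nn.Module):

  @nn.compact
  def __call__(self, x, train: bool, dropout_rate: float = 0.1):
    return make_mlp([16, 16], nn.relu, train, dropout_rate)(x)


x = jnp.ones((2, 8))
variables = TinyGNNLayer().init(jax.random.PRNGKey(0), x, train=False)
y = TinyGNNLayer().apply(
    variables, x, train=True, dropout_rate=0.1,
    rngs={'dropout': jax.random.PRNGKey(1)})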
From 894f4fb50f5bfdf0e4d2e197cf090e507a05fc15 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 18 Jun 2025 20:29:37 +0000
Subject: [PATCH 063/123] fixes to fastmri and deepspeech
---
algoperf/workloads/fastmri/fastmri_jax/workload.py | 11 +++++------
.../librispeech_conformer/librispeech_jax/models.py | 2 +-
2 files changed, 6 insertions(+), 7 deletions(-)
diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py
index b8067cbad..ccf9c6bad 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/workload.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py
@@ -58,12 +58,11 @@ def model_fn(
del update_batch_norm
train = mode == spec.ForwardPassMode.TRAIN
- if train:
- logits = self._model.apply({'params': params},
- augmented_and_preprocessed_input_batch['inputs'],
- rngs={'dropout': rng},
- train=train,
- dropout_rate=dropout_rate)
+ logits = self._model.apply({'params': params},
+ augmented_and_preprocessed_input_batch['inputs'],
+ rngs={'dropout': rng},
+ train=train,
+ dropout_rate=dropout_rate)
return logits, None
# Does NOT apply regularization, which is left to the submitter to do in
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
index 0de6b1449..366e42195 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
@@ -642,7 +642,7 @@ def __call__(self,
train,
update_batch_norm: Optional[bool] = None,
use_running_average_bn: Optional[bool] = None,
- dropout_rate: float = DROPOUT_RATE:
+ dropout_rate: float = DROPOUT_RATE):
config = self.config
outputs = inputs
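The unified apply call above now hands the model a 'dropout' rng in eval mode as well; this is harmless because the forked Dropout returns its input on deterministic=True before any rng is consumed. A quick standalone check of that behaviour:

import jax
import jax.numpy as jnp

from algoperf.jax_utils import Dropout

x = jnp.ones((2, 3))
layer = Dropout(rate=0.5)
variables = layer.init(jax.random.PRNGKey(0), x, deterministic=True)
# Passing a 'dropout' rng while deterministic is a no-op: the input comes
# back unchanged.
eval_out = layer.apply(
    variables, x, deterministic=True,
    rngs={'dropout': jax.random.PRNGKey(1)})
assert (eval_out == x).all()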
From 0bcf484282777f39d01b64e41ebba773aef1c913 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 18 Jun 2025 20:36:49 +0000
Subject: [PATCH 064/123] fixes to conformer vit
---
algoperf/workloads/imagenet_vit/imagenet_jax/models.py | 4 ++--
.../librispeech_conformer/librispeech_jax/models.py | 3 ---
.../librispeech_conformer/librispeech_jax/workload.py | 3 ---
.../librispeech_deepspeech/librispeech_jax/models.py | 6 +-----
algoperf/workloads/wmt/wmt_jax/models.py | 2 +-
5 files changed, 4 insertions(+), 14 deletions(-)
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
index 7c5d7bd26..091a3473e 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
@@ -125,14 +125,14 @@ class Encoder(nn.Module):
depth: int
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
num_heads: int = 12
- dropout_rate: float = 0.0
use_glu: bool = False
use_post_layer_norm: bool = False
@nn.compact
def __call__(self,
x: spec.Tensor,
- train: bool = True) -> spec.Tensor:
+ train: bool = True,
+ dropout_rate: float = DROPOUT_RATE) -> spec.Tensor:
# Input Encoder
for lyr in range(self.depth):
block = Encoder1DBlock(
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
index 366e42195..bf0eb813e 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
@@ -38,7 +38,6 @@ class ConformerConfig:
encoder_dim: int = 512
num_attention_heads: int = 8
num_encoder_layers: int = 4
- dropout_rate: float = DROPOUT_RATE
convolution_kernel_size: int = 5
feed_forward_expansion_factor: int = 4
freq_mask_count: int = 2
@@ -191,8 +190,6 @@ class FeedForwardModule(nn.Module):
@nn.compact
def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_RATE):
config = self.config
- if dropout_rate is None:
- dropout_rate = config.dropout_rate
inputs = LayerNorm(dim=config.encoder_dim)(inputs)
inputs = nn.Dense(
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
index 8d966ef87..eec707e5f 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
@@ -71,9 +71,6 @@ def init_model_fn(
else:
activation_function_name = 'swish'
model_config = models.ConformerConfig(
- attention_residual_dropout_rate=dropout_rate,
- feed_forward_residual_dropout_rate=dropout_rate,
- input_dropout_rate=dropout_rate,
use_specaug=self.use_specaug,
attention_temperature=self.attention_temperature,
use_post_layer_norm=self.use_post_layer_norm,
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
index 84ba58ee2..b47b1359a 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
@@ -105,12 +105,8 @@ def __call__(self, inputs, output_paddings, train, dropout_rate=DROPOUT_RATE):
kernel_init=nn.initializers.xavier_uniform())(
outputs)
- if config.input_dropout_rate is None:
- input_dropout_rate = 0.1
- else:
- input_dropout_rate = config.input_dropout_rate
outputs = Dropout(
- rate=input_dropout_rate, deterministic=not train)(
+ rate=dropout_rate, deterministic=not train)(
outputs, rate=dropout_rate)
return outputs, output_paddings
diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py
index e262214ac..38f76db80 100644
--- a/algoperf/workloads/wmt/wmt_jax/models.py
+++ b/algoperf/workloads/wmt/wmt_jax/models.py
@@ -574,7 +574,7 @@ def decode(
targets_positions=targets_positions,
decoder_mask=decoder_mask,
encoder_decoder_mask=encoder_decoder_mask,
- dropout_rate=droput_rate)
+ dropout_rate=dropout_rate)
return logits.astype(self.config.dtype)
def __call__(self,
From 73c2276cb1907534f16b76f82e95c95400d04f8f Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 18 Jun 2025 20:48:56 +0000
Subject: [PATCH 065/123] conformer and vit fix for dropout refactor
---
algoperf/workloads/imagenet_vit/imagenet_jax/models.py | 3 +--
.../librispeech_conformer/librispeech_jax/models.py | 7 +++----
2 files changed, 4 insertions(+), 6 deletions(-)
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
index 091a3473e..a78a5e791 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
@@ -226,8 +226,7 @@ def __call__(self,
num_heads=self.num_heads,
use_glu=self.use_glu,
use_post_layer_norm=self.use_post_layer_norm,
- name='Transformer',
- dropout_rate=dropout_rate)(
+ name='Transformer',)(
x, train=not train, dropout_rate=dropout_rate)
if self.use_map:
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
index bf0eb813e..1c2d79e15 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
@@ -206,8 +206,8 @@ def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_
'config.activation_function_name values, recieved '
f'{config.activation_function_name}')
inputs = activation_fn(inputs)
- inputs = Dropout(rate=config.feed_forward_dropout_rate)(
- inputs, deterministic=not train)
+ inputs = Dropout(rate=dropout_rate)(
+ inputs, deterministic=not train, rate=dropout_rate)
inputs = inputs * padding_mask
@@ -665,8 +665,7 @@ def __call__(self,
outputs, output_paddings = self.specaug(outputs, output_paddings)
outputs, output_paddings = Subsample(
- encoder_dim=config.encoder_dim,
- dropout_rate=dropout_rate)(
+ encoder_dim=config.encoder_dim,)(
outputs, output_paddings, train, dropout_rate=dropout_rate)
# Run the conformer encoder layers.
From 5ff94d23242a6613dc5d62579a9a4fe44d017eec Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 18 Jun 2025 21:12:10 +0000
Subject: [PATCH 066/123] wmt fixes
---
.../imagenet_vit/imagenet_jax/models.py | 54 +++++++++----------
algoperf/workloads/wmt/wmt_jax/workload.py | 2 +-
2 files changed, 27 insertions(+), 29 deletions(-)
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
index a78a5e791..716bd4239 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
@@ -85,7 +85,7 @@ def __call__(self,
deterministic=train,
name='MultiHeadDotProductAttention_1')(
y)
- y = Dropout(dropout_rate)(y, train, dropout_rate=dropout_rate)
+ y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
x = x + y
y = nn.LayerNorm(name='LayerNorm_2')(x)
@@ -121,33 +121,31 @@ def __call__(self,
class Encoder(nn.Module):
- """Transformer Model Encoder for sequence to sequence translation."""
- depth: int
- mlp_dim: Optional[int] = None # Defaults to 4x input dim.
- num_heads: int = 12
- use_glu: bool = False
- use_post_layer_norm: bool = False
-
- @nn.compact
- def __call__(self,
- x: spec.Tensor,
- train: bool = True,
- dropout_rate: float = DROPOUT_RATE) -> spec.Tensor:
- # Input Encoder
- for lyr in range(self.depth):
- block = Encoder1DBlock(
- name=f'encoderblock_{lyr}',
- mlp_dim=self.mlp_dim,
- num_heads=self.num_heads,
- use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm,
- dropout_rate=dropout_rate)(
- dropout_rate=dropout_rate)
- x = block(x, train)
- if not self.use_post_layer_norm:
- return nn.LayerNorm(name='encoder_layernorm')(x)
- else:
- return x
+ """Transformer Model Encoder for sequence to sequence translation."""
+
+ depth: int
+ mlp_dim: Optional[int] = None # Defaults to 4x input dim.
+ num_heads: int = 12
+ use_glu: bool = False
+ use_post_layer_norm: bool = False
+
+ @nn.compact
+ def __call__(
+ self, x: spec.Tensor, train: bool = True, dropout_rate: float = DROPOUT_RATE
+ ) -> spec.Tensor:
+ # Input Encoder
+ for lyr in range(self.depth):
+ x = Encoder1DBlock(
+ name=f"encoderblock_{lyr}",
+ mlp_dim=self.mlp_dim,
+ num_heads=self.num_heads,
+ use_glu=self.use_glu,
+ use_post_layer_norm=self.use_post_layer_norm,
+ )(x, train=train, dropout_rate=dropout_rate)
+ if not self.use_post_layer_norm:
+ return nn.LayerNorm(name="encoder_layernorm")(x)
+ else:
+ return x
class MAPHead(nn.Module):
diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py
index 9548f5b7e..24d4852b8 100644
--- a/algoperf/workloads/wmt/wmt_jax/workload.py
+++ b/algoperf/workloads/wmt/wmt_jax/workload.py
@@ -278,7 +278,7 @@ def model_fn(
inputs_segmentation=inputs_segmentations,
targets_segmentation=targets_segmentations,
rngs={'dropout': rng},
- dropout_rate=None)
+ dropout_rate=dropout_rate)
return logits_batch, None
def _normalize_eval_metrics(
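The pattern adopted here keeps dropout out of module constructors entirely: the parent holds no rate field and simply forwards the call-time value into each block. A minimal sketch under that assumption (Block and Stack are illustrative names); since dropout has no parameters and Flax keys the parameter tree by module names, changing the rate at call time leaves the checkpointed tree untouched.

import flax.linen as nn
import jax
import jax.numpy as jnp

from algoperf.jax_utils import Dropout


class Block(nn.Module):
  features: int = 8

  @nn.compact
  def __call__(self, x, train: bool, dropout_rate: float):
    x = nn.Dense(self.features)(x)
    return Dropout(rate=dropout_rate)(
        x, deterministic=not train, rate=dropout_rate)


class Stack(nn.Module):
  depth: int = 3

  @nn.compact
  def __call__(self, x, train: bool = True, dropout_rate: float = 0.1):
    # No dropout field on the parent: the rate flows only through the call.
    for lyr in range(self.depth):
      x = Block(name=f'block_{lyr}')(
          x, train=train, dropout_rate=dropout_rate)
    return x


x = jnp.ones((1, 8))
variables = Stack().init(jax.random.PRNGKey(0), x, train=False)
print(jax.tree_util.tree_map(jnp.shape, variables))
out = Stack().apply(
    variables, x, train=True, dropout_rate=0.2,
    rngs={'dropout': jax.random.PRNGKey(1)})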
From 9090e43e1970454dd250f277a7a26cbc4ff6f8dd Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 18 Jun 2025 21:16:14 +0000
Subject: [PATCH 067/123] fix linting
---
algoperf/pytorch_utils.py | 7 ++-
.../criteo1tb/criteo1tb_pytorch/models.py | 7 ++-
.../criteo1tb/criteo1tb_pytorch/workload.py | 7 +--
.../fastmri/fastmri_pytorch/models.py | 19 ++----
.../fastmri/fastmri_pytorch/workload.py | 7 +--
.../imagenet_pytorch/workload.py | 7 +--
.../imagenet_vit/imagenet_pytorch/models.py | 18 +++---
.../imagenet_vit/imagenet_pytorch/workload.py | 12 ++--
.../librispeech_pytorch/models.py | 19 +++---
.../librispeech_pytorch/workload.py | 7 +--
.../librispeech_pytorch/models.py | 3 +-
.../librispeech_pytorch/workload.py | 19 +++---
.../workloads/ogbg/ogbg_pytorch/models.py | 10 +--
.../workloads/ogbg/ogbg_pytorch/workload.py | 12 ++--
algoperf/workloads/wmt/wmt_pytorch/models.py | 63 +++++++++++--------
.../workloads/wmt/wmt_pytorch/workload.py | 7 +--
.../test_model_equivalence.py | 23 +++----
.../fastmri_pytorch/test_model_equivalence.py | 14 +++--
.../test_model_equivalence.py | 16 +++--
.../test_model_equivalence.py | 24 +++----
.../test_model_equivalence.py | 24 +++----
.../ogbg_pytorch/test_model_equivalence.py | 14 +++--
.../wmt_pytorch/test_model_equivalence.py | 21 ++++---
23 files changed, 193 insertions(+), 167 deletions(-)
diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py
index 4af77088e..bae26dea0 100644
--- a/algoperf/pytorch_utils.py
+++ b/algoperf/pytorch_utils.py
@@ -6,8 +6,8 @@
import tensorflow as tf
import torch
from torch import Tensor
-import torch.nn as nn
import torch.distributed as dist
+import torch.nn as nn
import torch.nn.functional as F
from algoperf import spec
@@ -84,6 +84,7 @@ def update_batch_norm_fn(module: spec.ParameterContainer,
class CustomDropout(nn.Module):
"""A module around torch.nn.functional.dropout."""
+
def __init__(self):
super().__init__()
self._supports_custom_dropout = True
@@ -94,6 +95,7 @@ def forward(self, input: Tensor, p: float) -> Tensor:
class CustomDropout2d(nn.Module):
"""A module around torch.nn.functional.dropout2d."""
+
def __init__(self):
super().__init__()
self._supports_custom_dropout = True
@@ -104,6 +106,7 @@ def forward(self, input: Tensor, p: float) -> Tensor:
class SequentialWithDropout(nn.Sequential):
"""Sequential of modules with dropout."""
+
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._supports_custom_dropout = True
@@ -113,5 +116,5 @@ def forward(self, x: Tensor, p: float) -> Tensor:
if getattr(module, '_supports_custom_dropout', False):
x = module(x, p)
else:
- x = module(x)
+ x = module(x)
return x
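The dispatch convention above lets one sequential container pass the call-time dropout probability only to submodules that understand it. A standalone sketch of that mechanism (class names here are illustrative stand-ins for the repository's CustomDropout and SequentialWithDropout):

import torch
import torch.nn as nn
import torch.nn.functional as F


class DropoutWithRate(nn.Module):
  """Functional dropout that takes its rate at call time."""

  def __init__(self):
    super().__init__()
    self._supports_custom_dropout = True

  def forward(self, x: torch.Tensor, p: float) -> torch.Tensor:
    return F.dropout(x, p, training=self.training)


class SequentialWithRate(nn.Sequential):
  """Sequential that forwards p only to rate-aware submodules."""

  def forward(self, x: torch.Tensor, p: float) -> torch.Tensor:
    for module in self:
      if getattr(module, '_supports_custom_dropout', False):
        x = module(x, p)
      else:
        x = module(x)
    return x


block = SequentialWithRate(nn.Linear(8, 8), nn.ReLU(), DropoutWithRate())
out = block(torch.randn(4, 8), p=0.1)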
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py
index f0653a665..7574de3a7 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py
@@ -5,13 +5,15 @@
import torch
from torch import nn
-from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout
+from algoperf.pytorch_utils import CustomDropout
+from algoperf.pytorch_utils import SequentialWithDropout
DROPOUT_RATE = 0.0
class DenseBlock(nn.Module):
"""Dense block with optional residual connection.""" ""
+
def __init__(self, module, resnet=False):
super().__init__()
self.module = module
@@ -23,6 +25,7 @@ def forward(self, x):
class DenseBlockWithDropout(nn.Module):
"""Dense block with optional residual connection and support for dropout."""
+
def __init__(self, module, resnet=False):
super().__init__()
self.module = module
@@ -30,7 +33,7 @@ def __init__(self, module, resnet=False):
self._supports_custom_dropout = True
def forward(self, x, p):
- return self.module(x, p) + x if self.resnet else self.module(x, p)
+ return self.module(x, p) + x if self.resnet else self.module(x, p)
class DotInteract(nn.Module):
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
index 48c6592f2..69f24c69d 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
@@ -65,9 +65,7 @@ def loss_fn(
'per_example': per_example_losses,
}
- def init_model_fn(
- self,
- rng: spec.RandomState) -> spec.ModelInitState:
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
# Disable cudnn benchmark to avoid OOM errors.
torch.backends.cudnn.benchmark = False
@@ -104,7 +102,8 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models.py b/algoperf/workloads/fastmri/fastmri_pytorch/models.py
index 0b8ac5499..3e7d7671c 100644
--- a/algoperf/workloads/fastmri/fastmri_pytorch/models.py
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/models.py
@@ -13,7 +13,8 @@
from torch.nn import functional as F
from algoperf import init_utils
-from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout
+from algoperf.pytorch_utils import CustomDropout2d
+from algoperf.pytorch_utils import SequentialWithDropout
DROPOUT_RATE = 0.0
@@ -39,12 +40,8 @@ def __init__(self,
self.num_channels = num_channels
self.num_pool_layers = num_pool_layers
- self.down_sample_layers = nn.ModuleList([
- ConvBlock(in_chans,
- num_channels,
- use_tanh,
- use_layer_norm)
- ])
+ self.down_sample_layers = nn.ModuleList(
+ [ConvBlock(in_chans, num_channels, use_tanh, use_layer_norm)])
ch = num_channels
for _ in range(num_pool_layers - 1):
self.down_sample_layers.append(
@@ -58,8 +55,7 @@ def __init__(self,
for _ in range(num_pool_layers - 1):
self.up_transpose_conv.append(
TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
- self.up_conv.append(
- ConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
+ self.up_conv.append(ConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
ch //= 2
self.up_transpose_conv.append(
@@ -74,10 +70,7 @@ def __init__(self,
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
init_utils.pytorch_default_init(m)
- def forward(
- self,
- x: Tensor,
- dropout_rate: float = DROPOUT_RATE) -> Tensor:
+ def forward(self, x: Tensor, dropout_rate: float = DROPOUT_RATE) -> Tensor:
stack = []
output = x
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
index 9b96230fc..1adbb57ca 100644
--- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
@@ -106,9 +106,7 @@ def _build_input_queue(self,
batch['volume_max'] = aux_tensors[2][RANK]
yield batch
- def init_model_fn(
- self,
- rng: spec.RandomState) -> spec.ModelInitState:
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
model = UNet(
num_pool_layers=self.num_pool_layers,
@@ -136,7 +134,8 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
index f28eb1762..285ba3b4b 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
@@ -154,9 +154,7 @@ def _build_dataset(
return dataloader
- def init_model_fn(
- self,
- rng: spec.RandomState) -> spec.ModelInitState:
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
if self.use_silu and self.use_gelu:
@@ -190,7 +188,8 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = 0.0) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = 0.0
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del dropout_rate
diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py
index 60e09edb5..9453780d0 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py
@@ -14,7 +14,8 @@
from algoperf import init_utils
from algoperf import spec
-from algoperf.workloads.wmt.wmt_pytorch.models_dropout import MultiheadAttention
+from algoperf.workloads.wmt.wmt_pytorch.models_dropout import \
+ MultiheadAttention
DROPOUT_RATE = 0.0
@@ -86,9 +87,7 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
class SelfAttention(nn.Module):
"""Self-attention special case of multi-head dot-product attention."""
- def __init__(self,
- width: int,
- num_heads: int = 8) -> None:
+ def __init__(self, width: int, num_heads: int = 8) -> None:
super().__init__()
self.width = width
@@ -161,9 +160,7 @@ def __init__(self,
self.self_attention1 = SelfAttention(self.width, self.num_heads)
self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6)
self.mlp3 = MlpBlock(
- width=self.width,
- mlp_dim=self.mlp_dim,
- use_glu=self.use_glu)
+ width=self.width, mlp_dim=self.mlp_dim, use_glu=self.use_glu)
def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
@@ -345,10 +342,9 @@ def reset_parameters(self) -> None:
def get_posemb(self, x: spec.Tensor) -> spec.Tensor:
return posemb_sincos_2d(x).type(self.dtype)
- def forward(
- self,
- x: spec.Tensor,
- dropout_rate: float = DROPOUT_RATE) -> spec.Tensor:
+ def forward(self,
+ x: spec.Tensor,
+ dropout_rate: float = DROPOUT_RATE) -> spec.Tensor:
# Patch extraction.
x = self.conv_patch_extract(x)
diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
index f86a1b1c2..e1c6844fe 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
@@ -21,9 +21,7 @@
# Make sure we inherit from the ViT base workload first.
class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload):
- def init_model_fn(
- self,
- rng: spec.RandomState) -> spec.ModelInitState:
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
model = models.ViT(
num_classes=self._num_classes,
@@ -52,7 +50,8 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
@@ -71,8 +70,9 @@ def model_fn(
}
with contexts[mode]():
- logits_batch = model(augmented_and_preprocessed_input_batch['inputs'],
- dropout_rate=dropout_rate)
+ logits_batch = model(
+ augmented_and_preprocessed_input_batch['inputs'],
+ dropout_rate=dropout_rate)
return logits_batch, None
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py
index a6a60bf95..10b8e585a 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py
@@ -73,9 +73,7 @@ def forward(self, x):
class Subsample(nn.Module):
- def __init__(self,
- encoder_dim: int = 0,
- num_bins: int = 80):
+ def __init__(self, encoder_dim: int = 0, num_bins: int = 80):
super().__init__()
self.encoder_dim = encoder_dim
@@ -105,7 +103,8 @@ def forward(self, inputs, input_paddings, dropout_rate):
outputs = self.linear(outputs)
outputs = outputs + self.pos_encode(seq_length=outputs.shape[1])
- outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True)
+ outputs = F.dropout(
+ outputs, dropout_rate, training=self.training, inplace=True)
return outputs, output_paddings
@@ -215,7 +214,8 @@ def forward(self, inputs, padding_mask, dropout_rate):
inputs = inputs * padding_mask
inputs = self.linear2(inputs)
inputs = inputs * padding_mask
- inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True)
+ inputs = F.dropout(
+ inputs, dropout_rate, training=self.training, inplace=True)
return inputs
@@ -309,7 +309,8 @@ def forward(self, outputs, paddings, dropout_rate):
outputs,
key_padding_mask=paddings == 1,
)
- outputs = F.dropout(outputs, dropout_rate, training=self.training, inplace=True)
+ outputs = F.dropout(
+ outputs, dropout_rate, training=self.training, inplace=True)
return outputs
@@ -412,7 +413,8 @@ def forward(self, inputs, input_paddings, dropout_rate):
inputs = activation_fn(inputs)
inputs = self.lin3(inputs)
- inputs = F.dropout(inputs, dropout_rate, training=self.training, inplace=True)
+ inputs = F.dropout(
+ inputs, dropout_rate, training=self.training, inplace=True)
return inputs
@@ -460,8 +462,7 @@ def __init__(self, config: ConformerConfig):
use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames
)
self.subsample = Subsample(
- encoder_dim=config.encoder_dim,
- num_bins=preprocessing_config.num_bins)
+ encoder_dim=config.encoder_dim, num_bins=preprocessing_config.num_bins)
self.conformers = nn.ModuleList(
[ConformerBlock(config) for _ in range(config.num_encoder_layers)])
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
index 0477a7389..dbeabb16c 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
@@ -61,9 +61,7 @@ def use_gelu(self) -> bool:
def attention_temperature(self) -> float:
return 1.0
- def init_model_fn(
- self,
- rng: spec.RandomState) -> spec.ModelInitState:
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
"""Conformer model init function."""
torch.random.manual_seed(rng[0])
# Configure torch backends to avoid OOM errors.
@@ -106,7 +104,8 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
index a8480a343..644c13a16 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
@@ -367,7 +367,8 @@ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE):
for idx in range(self.config.num_ffn_layers):
if self.config.enable_residual_connections:
- outputs = outputs + self.ffns[idx](outputs, output_paddings, dropout_rate)
+ outputs = outputs + self.ffns[idx](
+ outputs, output_paddings, dropout_rate)
else:
outputs = self.ffns[idx](outputs, output_paddings, dropout_rate)
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
index bf345cfc9..7640d69a5 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
@@ -23,9 +23,7 @@
class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload):
- def init_model_fn(
- self,
- rng: spec.RandomState) -> spec.ModelInitState:
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
"""Deepspeech model init function."""
torch.random.manual_seed(rng[0])
model = DeepspeechEncoderDecoder(
@@ -55,7 +53,7 @@ def init_model_fn(
else:
model = torch.nn.DataParallel(model)
return model, None
-
+
def model_fn(
self,
params: spec.ParameterContainer,
@@ -64,11 +62,16 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
# override super method, changing only the default dropout_rate
- return super().model_fn(
- params, augmented_and_preprocessed_input_batch, model_state,
- mode, rng, update_batch_norm, dropout_rate)
+ return super().model_fn(params,
+ augmented_and_preprocessed_input_batch,
+ model_state,
+ mode,
+ rng,
+ update_batch_norm,
+ dropout_rate)
def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['lin.weight', 'lin.bias']
diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models.py b/algoperf/workloads/ogbg/ogbg_pytorch/models.py
index be5882333..8a40bef58 100644
--- a/algoperf/workloads/ogbg/ogbg_pytorch/models.py
+++ b/algoperf/workloads/ogbg/ogbg_pytorch/models.py
@@ -9,7 +9,8 @@
from torch import nn
from algoperf import init_utils
-from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout
+from algoperf.pytorch_utils import CustomDropout
+from algoperf.pytorch_utils import SequentialWithDropout
DROPOUT_RATE = 0.1
@@ -93,10 +94,9 @@ def __init__(self,
if isinstance(m, nn.Linear):
init_utils.pytorch_default_init(m)
- def forward(
- self,
- graph: GraphsTuple,
- dropout_rate: float = DROPOUT_RATE) -> torch.Tensor:
+ def forward(self,
+ graph: GraphsTuple,
+ dropout_rate: float = DROPOUT_RATE) -> torch.Tensor:
graph = graph._replace(
globals=torch.zeros([graph.n_node.shape[0], self.num_outputs],
diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
index 758b36b60..a45a93668 100644
--- a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
+++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
@@ -137,9 +137,7 @@ def _build_input_queue(self,
yield batch
- def init_model_fn(
- self,
- rng: spec.RandomState) -> spec.ModelInitState:
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
model = GNN(
num_outputs=self._num_outputs,
@@ -168,7 +166,8 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
"""Get predicted logits from the network for input graphs."""
del rng
del update_batch_norm # No BN in the GNN model.
@@ -188,8 +187,9 @@ def model_fn(
}
with contexts[mode]():
- logits = model(augmented_and_preprocessed_input_batch['inputs'],
- dropout_rate=dropout_rate)
+ logits = model(
+ augmented_and_preprocessed_input_batch['inputs'],
+ dropout_rate=dropout_rate)
return logits, None
diff --git a/algoperf/workloads/wmt/wmt_pytorch/models.py b/algoperf/workloads/wmt/wmt_pytorch/models.py
index a43df30d4..0de719c4b 100644
--- a/algoperf/workloads/wmt/wmt_pytorch/models.py
+++ b/algoperf/workloads/wmt/wmt_pytorch/models.py
@@ -226,7 +226,8 @@ def __init__(self,
self.enable_nested_tensor = enable_nested_tensor
self.mask_check = mask_check
- def forward(self, src: Tensor,
+ def forward(self,
+ src: Tensor,
mask: Optional[Tensor] = None,
dropout_rate: Optional[float] = 0.0) -> Tensor:
"""Pass the input through the encoder layers in turn.
@@ -341,7 +342,12 @@ def forward(
if not decode:
tgt = shift_right(tgt)
tgt = self.shared_embedding(tgt)
- tgt = self.pos_encoder(tgt, targets_positions, decode=decode, cache=cache, dropout_rate=dropout_rate)
+ tgt = self.pos_encoder(
+ tgt,
+ targets_positions,
+ decode=decode,
+ cache=cache,
+ dropout_rate=dropout_rate)
if decode:
tgt, cache = tgt
output = self.decoder(
@@ -364,9 +370,7 @@ def forward(
class PositionalEncoding(nn.Module):
- def __init__(self,
- d_model: int,
- max_len: int = 256):
+ def __init__(self, d_model: int, max_len: int = 256):
super().__init__()
position = torch.arange(max_len).unsqueeze(1)
@@ -481,9 +485,9 @@ def __init__(self,
self.activation = activation
- def forward(self,
- src: Tensor,
- src_mask: Optional[Tensor] = None,
+ def forward(self,
+ src: Tensor,
+ src_mask: Optional[Tensor] = None,
dropout_rate: Optional[float] = 0.0) -> Tensor:
r"""Pass the input through the encoder layer.
@@ -505,16 +509,16 @@ def forward(self,
return x
# Self-attention block:
- def _sa_block(self,
- x: Tensor,
- attn_mask: Optional[Tensor],
+ def _sa_block(self,
+ x: Tensor,
+ attn_mask: Optional[Tensor],
dropout_rate: Optional[float] = 0.0) -> Tensor:
x, _ = self.self_attn(x, attn_mask=attn_mask, dropout_rate=dropout_rate)
return F.dropout(x, dropout_rate, training=self.training)
# Feed forward block:
- def _ff_block(self,
- inputs: Tensor,
+ def _ff_block(self,
+ inputs: Tensor,
dropout_rate: Optional[float] = 0.0) -> Tensor:
x = self.activation(self.linear1(inputs))
if self.glu:
@@ -763,18 +767,21 @@ def _sa_block( # pylint: disable=arguments-renamed
return F.dropout(x, dropout_rate, self.training), cache
# Multihead attention block:
- def _mha_block(self, x: Tensor, mem: Tensor,
+ def _mha_block(self,
+ x: Tensor,
+ mem: Tensor,
attn_mask: Optional[Tensor],
dropout_rate: Optional[float] = 0.0) -> Tensor:
x, _ = self.multihead_attn(
- x,
- mem,
- attn_mask=attn_mask,
+ x,
+ mem,
+ attn_mask=attn_mask,
dropout_rate=dropout_rate)
return F.dropout(x, dropout_rate, self.training)
# Feed forward block.
- def _ff_block(self, inputs: Tensor,
+ def _ff_block(self,
+ inputs: Tensor,
dropout_rate: Optional[float] = 0.0) -> Tensor:
x = self.activation(self.linear1(inputs))
if self.glu:
@@ -847,15 +854,17 @@ def _reset_parameters(self):
if module.bias is not None:
normal_(module.bias, std=1e-6)
- def forward(self,
- x: Tensor,
- mem: Optional[Tensor] = None,
- attn_mask: Optional[Tensor] = None,
- decode: bool = False,
- max_len: Optional[int] = None,
- cache: Optional[dict] = None,
- index: Optional[int] = None,
- dropout_rate: Optional[float] = 0.0) -> Any: # TODO: (nico) remove default?!
+ def forward(
+ self,
+ x: Tensor,
+ mem: Optional[Tensor] = None,
+ attn_mask: Optional[Tensor] = None,
+ decode: bool = False,
+ max_len: Optional[int] = None,
+ cache: Optional[dict] = None,
+ index: Optional[int] = None,
+ dropout_rate: Optional[float] = 0.0
+ ) -> Any: # TODO: (nico) remove default?!
r"""
Args:
x: Batch of input sequences of shape
diff --git a/algoperf/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py
index 4c787becc..4ec816f2f 100644
--- a/algoperf/workloads/wmt/wmt_pytorch/workload.py
+++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py
@@ -166,9 +166,7 @@ def translate_and_calculate_bleu(self,
bleu_score = bleu.corpus_bleu(predictions, [references]).score
return bleu_score
- def init_model_fn(
- self,
- rng: spec.RandomState) -> spec.ModelInitState:
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
if self.activation == 'relu':
@@ -204,7 +202,8 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
index db56b17cf..f51f1f9aa 100644
--- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
@@ -4,20 +4,21 @@
python3 tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
"""
-from absl.testing import absltest, parameterized
-from torch.testing import assert_close
-import torch
import os
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import (
- DLRMResNet as OriginalDLRMResNet,
- DlrmSmall as OriginalDlrmSmall,
-)
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models_dropout import (
- DLRMResNet as CustomDLRMResNet,
- DlrmSmall as CustomDlrmSmall,
-)
+from absl.testing import absltest
+from absl.testing import parameterized
+import torch
+from torch.testing import assert_close
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import \
+ DLRMResNet as OriginalDLRMResNet
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import \
+ DlrmSmall as OriginalDlrmSmall
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models_dropout import \
+ DLRMResNet as CustomDLRMResNet
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models_dropout import \
+ DlrmSmall as CustomDlrmSmall
BATCH, DENSE, SPARSE = 16, 13, 26
FEATURES = DENSE + SPARSE
diff --git a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
index 0d3d52980..46f1a0f5a 100644
--- a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
@@ -4,13 +4,17 @@
python3 tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
"""
-from absl.testing import absltest, parameterized
-from torch.testing import assert_close
-import torch
import os
-from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet as OriginalUNet
-from algoperf.workloads.fastmri.fastmri_pytorch.models_dropout import UNet as CustomUNet
+from absl.testing import absltest
+from absl.testing import parameterized
+import torch
+from torch.testing import assert_close
+
+from algoperf.workloads.fastmri.fastmri_pytorch.models import \
+ UNet as OriginalUNet
+from algoperf.workloads.fastmri.fastmri_pytorch.models_dropout import \
+ UNet as CustomUNet
BATCH, IN_CHANS, H, W = 4, 1, 256, 256
OUT_CHANS, C, LAYERS = 1, 32, 4
diff --git a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
index d19fad0ba..e6d58f5f7 100644
--- a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
@@ -4,14 +4,18 @@
python3 tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
"""
-from absl.testing import absltest, parameterized
-from torch.testing import assert_close
-import torch
-import os
import itertools
+import os
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import torch
+from torch.testing import assert_close
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import ViT as OriginalVit
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.models_dropout import ViT as CustomVit
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import \
+ ViT as OriginalVit
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.models_dropout import \
+ ViT as CustomVit
# Model / test hyper-params
BATCH, C, H, W = 4, 3, 224, 224 # input shape (N,C,H,W)
diff --git a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
index 4a1252a39..d9511fcc4 100644
--- a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
@@ -6,19 +6,21 @@
NOTE: we don't test for default dropout_rate values, since they changed.
"""
-from absl.testing import absltest, parameterized
-from torch.testing import assert_close
-import torch
import os
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import (
- ConformerConfig as OriginalConfig,
- ConformerEncoderDecoder as OriginalModel
-)
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import(
- ConformerConfig as CustomConfig,
- ConformerEncoderDecoder as CustomModel,
-)
+from absl.testing import absltest
+from absl.testing import parameterized
+import torch
+from torch.testing import assert_close
+
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
+ ConformerConfig as OriginalConfig
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
+ ConformerEncoderDecoder as OriginalModel
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import \
+ ConformerConfig as CustomConfig
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import \
+ ConformerEncoderDecoder as CustomModel
N_LAYERS = 3
B, T = 32, 36_000
diff --git a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
index 58ddb354e..610e6b18e 100644
--- a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
@@ -8,19 +8,21 @@
- `feed_forward_dropout_rate` (if None, 0.1)
"""
-from absl.testing import absltest, parameterized
-from torch.testing import assert_close
-import torch
import os
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import (
- DeepspeechEncoderDecoder as OriginalModel,
- DeepspeechConfig as OriginalConfig
-)
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import(
- DeepspeechEncoderDecoder as CustomModel,
- DeepspeechConfig as CustomConfig
-)
+from absl.testing import absltest
+from absl.testing import parameterized
+import torch
+from torch.testing import assert_close
+
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \
+ DeepspeechConfig as OriginalConfig
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \
+ DeepspeechEncoderDecoder as OriginalModel
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import \
+ DeepspeechConfig as CustomConfig
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import \
+ DeepspeechEncoderDecoder as CustomModel
B, T = 32, 30_000
DEVICE = 'cuda'
diff --git a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
index aaca6cebd..f5c7e992b 100644
--- a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
@@ -4,13 +4,19 @@
python3 tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
"""
-from absl.testing import absltest, parameterized
-from torch.testing import assert_close
-import torch, os, random, numpy as np
+import os
+import random
+
+from absl.testing import absltest
+from absl.testing import parameterized
from jraph import GraphsTuple
+import numpy as np
+import torch
+from torch.testing import assert_close
from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as OriginalModel
-from algoperf.workloads.ogbg.ogbg_pytorch.models_dropout import GNN as CustomModel
+from algoperf.workloads.ogbg.ogbg_pytorch.models_dropout import \
+ GNN as CustomModel
B, N, E = 8, 20, 40 # graphs, nodes/graph, edges/graph
NODE_FDIM, EDGE_FDIM = 9, 3 # expected feature dims
diff --git a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
index 9675f1df2..918043cfd 100644
--- a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
@@ -4,16 +4,19 @@
python3 tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
"""
-from absl.testing import absltest, parameterized
+import os
+import random
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import numpy as np
+import torch
from torch.testing import assert_close
-import torch, os, random, numpy as np
-
-from algoperf.workloads.wmt.wmt_pytorch.models import (
- Transformer as OriginalModel,
-)
-from algoperf.workloads.wmt.wmt_pytorch.models_dropout import (
- Transformer as CustomModel,
-)
+
+from algoperf.workloads.wmt.wmt_pytorch.models import \
+ Transformer as OriginalModel
+from algoperf.workloads.wmt.wmt_pytorch.models_dropout import \
+ Transformer as CustomModel
B, SRC_LEN, TGT_LEN, NTOK = 16, 80, 80, 32_000
DEVICE = "cuda"
From 4e69255642807cbb38c9d3390898f9713e59ece7 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 18 Jun 2025 21:16:42 +0000
Subject: [PATCH 068/123] formatting
---
algoperf/jax_utils.py | 104 +++++++++---------
.../workloads/cifar/cifar_jax/workload.py | 4 +-
.../criteo1tb/criteo1tb_jax/models.py | 1 +
.../criteo1tb/criteo1tb_jax/workload.py | 12 +-
.../workloads/fastmri/fastmri_jax/models.py | 1 +
.../workloads/fastmri/fastmri_jax/workload.py | 16 +--
.../imagenet_resnet/imagenet_jax/workload.py | 3 +-
.../imagenet_vit/imagenet_jax/models.py | 58 +++++-----
.../imagenet_vit/imagenet_jax/workload.py | 12 +-
.../librispeech_jax/models.py | 15 ++-
.../librispeech_jax/workload.py | 3 +-
.../librispeech_jax/models.py | 4 +-
.../librispeech_jax/workload.py | 9 +-
.../workloads/mnist/mnist_jax/workload.py | 4 +-
algoperf/workloads/ogbg/ogbg_jax/models.py | 22 +++-
algoperf/workloads/ogbg/ogbg_jax/workload.py | 7 +-
algoperf/workloads/wmt/wmt_jax/models.py | 20 ++--
algoperf/workloads/wmt/wmt_jax/workload.py | 7 +-
18 files changed, 164 insertions(+), 138 deletions(-)
diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py
index c4904dc75..369eb1b1a 100644
--- a/algoperf/jax_utils.py
+++ b/algoperf/jax_utils.py
@@ -1,17 +1,19 @@
from collections.abc import Sequence
-import jax
-import jax.numpy as jnp
-from jax import lax, random
-
import flax.linen as nn
-from flax.linen.module import Module, compact, merge_param
+from flax.linen.module import compact
+from flax.linen.module import merge_param
+from flax.linen.module import Module
from flax.typing import PRNGKey
+import jax
+from jax import lax
+from jax import random
+import jax.numpy as jnp
# Custom Layers
class Dropout(Module):
- """Create a dropout layer.
+ """Create a dropout layer.
Forked from https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout.
   The reference dropout implementation is modified to support changes to dropout rate during training by:
1) adding rate argument to the __call__ method.
@@ -51,21 +53,21 @@ class Dropout(Module):
rng_collection: the rng collection name to use when requesting an rng key.
"""
- rate: float | None = None
- broadcast_dims: Sequence[int] = ()
- deterministic: bool | None = None
- rng_collection: str = "dropout"
- legacy: bool = True
-
- @compact
- def __call__(
- self,
- inputs,
- deterministic: bool | None = None,
- rate: float | None = None,
- rng: PRNGKey | None = None,
- ):
- """Applies a random dropout mask to the input.
+ rate: float | None = None
+ broadcast_dims: Sequence[int] = ()
+ deterministic: bool | None = None
+ rng_collection: str = "dropout"
+ legacy: bool = True
+
+ @compact
+ def __call__(
+ self,
+ inputs,
+ deterministic: bool | None = None,
+ rate: float | None = None,
+ rng: PRNGKey | None = None,
+ ):
+ """Applies a random dropout mask to the input.
Args:
inputs: the inputs that should be randomly masked.
@@ -79,40 +81,44 @@ def __call__(
Returns:
The masked inputs reweighted to preserve mean.
"""
- deterministic = merge_param("deterministic", self.deterministic, deterministic)
+ deterministic = merge_param("deterministic",
+ self.deterministic,
+ deterministic)
- # Override self.rate if rate is passed to __call__
- if rate is None:
- rate = self.rate
+ # Override self.rate if rate is passed to __call__
+ if rate is None:
+ rate = self.rate
- if self.legacy:
- if rate == 0.0:
- return inputs
+ if self.legacy:
+ if rate == 0.0:
+ return inputs
- # Prevent gradient NaNs in 1.0 edge-case.
- if rate == 1.0:
- return jnp.zeros_like(inputs)
+ # Prevent gradient NaNs in 1.0 edge-case.
+ if rate == 1.0:
+ return jnp.zeros_like(inputs)
- if deterministic:
- return inputs
+ if deterministic:
+ return inputs
- keep_prob = 1.0 - rate
- if rng is None:
- rng = self.make_rng(self.rng_collection)
- broadcast_shape = list(inputs.shape)
- for dim in self.broadcast_dims:
- broadcast_shape[dim] = 1
- mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
- mask = jnp.broadcast_to(mask, inputs.shape)
- return lax.select(mask, inputs, jnp.zeros_like(inputs))
+ keep_prob = 1.0 - rate
+ if rng is None:
+ rng = self.make_rng(self.rng_collection)
+ broadcast_shape = list(inputs.shape)
+ for dim in self.broadcast_dims:
+ broadcast_shape[dim] = 1
+ mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
+ mask = jnp.broadcast_to(mask, inputs.shape)
+ return lax.select(mask, inputs, jnp.zeros_like(inputs))
# Utilities for debugging
def print_jax_model_summary(model, fake_inputs):
- """Prints a summary of the jax module."""
- tabulate_fn = nn.tabulate(
- model,
- jax.random.PRNGKey(0),
- console_kwargs={"force_terminal": False, "force_jupyter": False, "width": 240},
- )
- print(tabulate_fn(fake_inputs, train=False))
+ """Prints a summary of the jax module."""
+ tabulate_fn = nn.tabulate(
+ model,
+ jax.random.PRNGKey(0),
+ console_kwargs={
+ "force_terminal": False, "force_jupyter": False, "width": 240
+ },
+ )
+ print(tabulate_fn(fake_inputs, train=False))
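For reference, a minimal usage sketch of the forked Dropout layer above: the dropout rate is supplied by the caller and threaded down through the module's __call__, which is the pattern the workloads below follow. TinyDenseBlock and its layer sizes are hypothetical and only illustrate the plumbing.

    import flax.linen as nn
    import jax
    import jax.numpy as jnp

    from algoperf.jax_utils import Dropout


    class TinyDenseBlock(nn.Module):
      """Hypothetical module showing call-time dropout_rate plumbing."""

      @nn.compact
      def __call__(self, x, train, dropout_rate=0.1):
        x = nn.Dense(16)(x)
        # The rate comes from the caller; deterministic disables dropout at eval.
        x = Dropout(rate=dropout_rate)(x, deterministic=not train)
        return nn.Dense(1)(x)


    model = TinyDenseBlock()
    x = jnp.ones((2, 8))
    variables = model.init(jax.random.PRNGKey(0), x, train=False)
    y = model.apply(
        variables,
        x,
        train=True,
        dropout_rate=0.2,
        rngs={'dropout': jax.random.PRNGKey(1)})

With train=False the deterministic branch returns the inputs unchanged, so no 'dropout' RNG is needed at initialization or eval time.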
diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py
index 3f2397f8c..c6cc50fbf 100644
--- a/algoperf/workloads/cifar/cifar_jax/workload.py
+++ b/algoperf/workloads/cifar/cifar_jax/workload.py
@@ -79,9 +79,7 @@ def sync_batch_stats(
new_model_state['batch_stats'] = avg_fn(model_state['batch_stats'])
return new_model_state
- def init_model_fn(
- self,
- rng: spec.RandomState) -> spec.ModelInitState:
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
"""Dropout is unused."""
model_cls = getattr(models, 'ResNet18')
model = model_cls(num_classes=self._num_classes, dtype=jnp.float32)
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
index 57cb7f2d9..4a91a80b8 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
@@ -9,6 +9,7 @@
DROPOUT_RATE = 0.0
+
class DLRMResNet(nn.Module):
"""Define a DLRMResNet model.
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
index cb7e8cf9f..d84d18d5c 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -89,16 +89,17 @@ def init_model_fn(
use_layer_norm=self.use_layer_norm,
embedding_init_multiplier=self.embedding_init_multiplier)
- params_rng, _= jax.random.split(rng)
+ params_rng, _ = jax.random.split(rng)
init_fake_batch_size = 2
num_categorical_features = 26
num_dense_features = 13
input_size = num_dense_features + num_categorical_features
input_shape = (init_fake_batch_size, input_size)
init_fn = functools.partial(self._model.init, train=False)
- initial_variables = jax.jit(init_fn)(
- {'params': params_rng,},
- jnp.ones(input_shape, jnp.float32))
+ initial_variables = jax.jit(init_fn)({
+ 'params': params_rng,
+ },
+ jnp.ones(input_shape, jnp.float32))
initial_params = initial_variables['params']
self._param_shapes = param_utils.jax_param_shapes(initial_params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
@@ -115,7 +116,8 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
inputs = augmented_and_preprocessed_input_batch['inputs']
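Taken out of the workload, the jitted-initialization pattern in this hunk reduces to roughly the sketch below; init_params is a hypothetical helper, and model is assumed to be a Flax module whose __call__ accepts a train flag.

    import functools

    import jax
    import jax.numpy as jnp


    def init_params(model, rng, input_shape):
      # Initialize with dropout disabled and jit the init function itself.
      init_fn = functools.partial(model.init, train=False)
      variables = jax.jit(init_fn)({'params': rng},
                                   jnp.ones(input_shape, jnp.float32))
      return variables['params']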
diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py
index 5850defa7..70c7fc4a5 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/models.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/models.py
@@ -23,6 +23,7 @@
DROPOUT_RATE = 0.0
+
def _instance_norm2d(x, axes, epsilon=1e-5):
# promote x to at least float32, this avoids half precision computation
# but preserves double or complex floating points
diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py
index ccf9c6bad..bd0aa1d0b 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/workload.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py
@@ -30,12 +30,11 @@ def init_model_fn(
num_channels=self.num_channels,
use_tanh=self.use_tanh,
use_layer_norm=self.use_layer_norm,
- )
+ )
params_rng, _ = jax.random.split(rng)
init_fn = functools.partial(self._model.init, train=False)
- variables = jax.jit(init_fn)({'params': params_rng},
- fake_batch)
+ variables = jax.jit(init_fn)({'params': params_rng}, fake_batch)
params = variables['params']
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
@@ -53,16 +52,17 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = DROPOUT_RATE
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
train = mode == spec.ForwardPassMode.TRAIN
logits = self._model.apply({'params': params},
- augmented_and_preprocessed_input_batch['inputs'],
- rngs={'dropout': rng},
- train=train,
- dropout_rate=dropout_rate)
+ augmented_and_preprocessed_input_batch['inputs'],
+ rngs={'dropout': rng},
+ train=train,
+ dropout_rate=dropout_rate)
return logits, None
# Does NOT apply regularization, which is left to the submitter to do in
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
index 2a255fee4..7896dcd05 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
@@ -84,7 +84,8 @@ def sync_batch_stats(
def init_model_fn(
self,
- rng: spec.RandomState,) -> spec.ModelInitState:
+ rng: spec.RandomState,
+ ) -> spec.ModelInitState:
model_cls = getattr(models, 'ResNet50')
if self.use_silu and self.use_gelu:
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
index 716bd4239..f33dea723 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
@@ -15,6 +15,7 @@
DROPOUT_RATE = 0.0
+
def posemb_sincos_2d(h: int,
w: int,
width: int,
@@ -121,31 +122,32 @@ def __call__(self,
class Encoder(nn.Module):
- """Transformer Model Encoder for sequence to sequence translation."""
-
- depth: int
- mlp_dim: Optional[int] = None # Defaults to 4x input dim.
- num_heads: int = 12
- use_glu: bool = False
- use_post_layer_norm: bool = False
-
- @nn.compact
- def __call__(
- self, x: spec.Tensor, train: bool = True, dropout_rate: float = DROPOUT_RATE
- ) -> spec.Tensor:
- # Input Encoder
- for lyr in range(self.depth):
- x = Encoder1DBlock(
- name=f"encoderblock_{lyr}",
- mlp_dim=self.mlp_dim,
- num_heads=self.num_heads,
- use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm,
- )(x, train=train, dropout_rate=dropout_rate)
- if not self.use_post_layer_norm:
- return nn.LayerNorm(name="encoder_layernorm")(x)
- else:
- return x
+ """Transformer Model Encoder for sequence to sequence translation."""
+
+ depth: int
+ mlp_dim: Optional[int] = None # Defaults to 4x input dim.
+ num_heads: int = 12
+ use_glu: bool = False
+ use_post_layer_norm: bool = False
+
+ @nn.compact
+ def __call__(self,
+ x: spec.Tensor,
+ train: bool = True,
+ dropout_rate: float = DROPOUT_RATE) -> spec.Tensor:
+ # Input Encoder
+ for lyr in range(self.depth):
+ x = Encoder1DBlock(
+ name=f"encoderblock_{lyr}",
+ mlp_dim=self.mlp_dim,
+ num_heads=self.num_heads,
+ use_glu=self.use_glu,
+ use_post_layer_norm=self.use_post_layer_norm,
+ )(x, train=train, dropout_rate=dropout_rate)
+ if not self.use_post_layer_norm:
+ return nn.LayerNorm(name="encoder_layernorm")(x)
+ else:
+ return x
class MAPHead(nn.Module):
@@ -182,7 +184,7 @@ class ViT(nn.Module):
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
num_heads: int = 12
rep_size: Union[int, bool] = True
- dropout_rate: [float] = DROPOUT_RATE
+  dropout_rate: float = DROPOUT_RATE
reinit: Optional[Sequence[str]] = None
head_zeroinit: bool = True
use_glu: bool = False
@@ -224,8 +226,8 @@ def __call__(self,
num_heads=self.num_heads,
use_glu=self.use_glu,
use_post_layer_norm=self.use_post_layer_norm,
- name='Transformer',)(
- x, train=not train, dropout_rate=dropout_rate)
+ name='Transformer',
+ )(x, train=not train, dropout_rate=dropout_rate)
if self.use_map:
x = MAPHead(
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
index 08a8f4eb1..d0fb4fd72 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -24,15 +24,12 @@ def initialized(self, key: spec.RandomState,
model: nn.Module) -> spec.ModelInitState:
input_shape = (1, 224, 224, 3)
params_rng, _ = jax.random.split(key)
- variables = jax.jit(
- model.init)({'params': params_rng},
- jnp.ones(input_shape))
+ variables = jax.jit(model.init)({'params': params_rng},
+ jnp.ones(input_shape))
model_state, params = pop(variables, "params")
return params, model_state
- def init_model_fn(
- self,
- rng: spec.RandomState) -> spec.ModelInitState:
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
self._model = models.ViT(
num_classes=self._num_classes,
use_glu=self.use_glu,
@@ -57,7 +54,8 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
train = mode == spec.ForwardPassMode.TRAIN
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
index 1c2d79e15..b2eee1c37 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
@@ -22,14 +22,15 @@
import jax.numpy as jnp
import numpy as np
+from algoperf.jax_utils import Dropout
from algoperf.workloads.librispeech_conformer.librispeech_jax import \
librispeech_preprocessor as preprocessor
from algoperf.workloads.librispeech_conformer.librispeech_jax import \
spectrum_augmenter
-from algoperf.jax_utils import Dropout
DROPOUT_RATE = 0.1
+
@struct.dataclass
class ConformerConfig:
"""Global hyperparameters used to minimize obnoxious kwarg plumbing."""
@@ -92,6 +93,7 @@ class Subsample(nn.Module):
input_dropout_rate: dropout rate for inputs.
"""
encoder_dim: int = 0
+
@nn.compact
def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE):
output_paddings = input_paddings
@@ -188,7 +190,11 @@ class FeedForwardModule(nn.Module):
config: ConformerConfig
@nn.compact
- def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_RATE):
+ def __call__(self,
+ inputs,
+ padding_mask=None,
+ train=False,
+ dropout_rate=DROPOUT_RATE):
config = self.config
inputs = LayerNorm(dim=config.encoder_dim)(inputs)
@@ -218,8 +224,7 @@ def __call__(self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_
inputs)
inputs = inputs * padding_mask
- inputs = Dropout(rate=dropout_rate)(
- inputs, deterministic=not train)
+ inputs = Dropout(rate=dropout_rate)(inputs, deterministic=not train)
return inputs
@@ -583,7 +588,7 @@ def __call__(self,
input_paddings,
train,
update_batch_norm,
- use_running_average,
+ use_running_average,
dropout_rate=DROPOUT_RATE):
config = self.config
padding_mask = jnp.expand_dims(1 - input_paddings, -1)
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
index eec707e5f..1e1a1d3f8 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
@@ -83,8 +83,7 @@ def init_model_fn(
model_init_fn = jax.jit(functools.partial(self._model.init, train=False))
params_rng, _ = jax.random.split(rng, 2)
- variables = model_init_fn({'params': params_rng},
- *fake_input_batch)
+ variables = model_init_fn({'params': params_rng}, *fake_input_batch)
model_state, params = pop(variables, "params")
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
index b47b1359a..1bd998027 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
@@ -16,11 +16,11 @@
from jax.experimental import rnn
import jax.numpy as jnp
+from algoperf.jax_utils import Dropout
from algoperf.workloads.librispeech_conformer.librispeech_jax import \
librispeech_preprocessor as preprocessor
from algoperf.workloads.librispeech_conformer.librispeech_jax import \
spectrum_augmenter
-from algoperf.jax_utils import Dropout
Array = jnp.ndarray
StateType = Union[Array, Tuple[Array, ...]]
@@ -31,7 +31,7 @@
CarryHistory = Any
Output = Any
-DROPOUT_RATE=0.1
+DROPOUT_RATE = 0.1
@struct.dataclass
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
index 2bb119439..81a56db72 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
@@ -15,9 +15,7 @@
class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload):
- def init_model_fn(
- self,
- rng: spec.RandomState) -> spec.ModelInitState:
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
"""Deepspeech model init function.
"""
model_config = models.DeepspeechConfig(
@@ -36,8 +34,9 @@ def init_model_fn(
model_init_fn = jax.jit(functools.partial(self._model.init, train=False))
params_rng, _ = jax.random.split(rng, 2)
- variables = model_init_fn({'params': params_rng,},
- *fake_input_batch)
+ variables = model_init_fn({
+ 'params': params_rng,
+ }, *fake_input_batch)
model_state = variables[
'batch_stats'] if not self.layernorm_everywhere else {}
diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py
index 5f3fdcf78..27bd9ae54 100644
--- a/algoperf/workloads/mnist/mnist_jax/workload.py
+++ b/algoperf/workloads/mnist/mnist_jax/workload.py
@@ -32,9 +32,7 @@ def __call__(self, x: spec.Tensor, train: bool) -> spec.Tensor:
class MnistWorkload(BaseMnistWorkload):
- def init_model_fn(
- self,
- rng: spec.RandomState) -> spec.ModelInitState:
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
self._model = _Model()
initial_params = self._model.init({'params': rng}, init_val,
diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py
index d51ca2f20..06eef6187 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/models.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/models.py
@@ -8,7 +8,8 @@
from algoperf.jax_utils import Dropout
-DROPOUT_RATE=0.1
+DROPOUT_RATE = 0.1
+
def _make_embed(latent_dim, name):
@@ -28,7 +29,9 @@ def make_fn(inputs):
x = nn.Dense(features=dim)(x)
x = nn.LayerNorm()(x)
x = activation_fn(x)
- x = Dropout(rate=dropout_rate, deterministic=not train)(x, rate=dropout_rate)
+ x = Dropout(
+ rate=dropout_rate, deterministic=not train)(
+ x, rate=dropout_rate)
return x
return make_fn
@@ -69,11 +72,20 @@ def __call__(self, graph, train, dropout_rate=DROPOUT_RATE):
for _ in range(self.num_message_passing_steps):
net = jraph.GraphNetwork(
update_edge_fn=_make_mlp(
- self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate),
+ self.hidden_dims,
+ activation_fn=activation_fn,
+ train=train,
+ dropout_rate=dropout_rate),
update_node_fn=_make_mlp(
- self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate),
+ self.hidden_dims,
+ activation_fn=activation_fn,
+ train=train,
+ dropout_rate=dropout_rate),
update_global_fn=_make_mlp(
- self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate))
+ self.hidden_dims,
+ activation_fn=activation_fn,
+ train=train,
+ dropout_rate=dropout_rate))
graph = net(graph)
diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py
index e03252ed9..04a9bce2e 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/workload.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py
@@ -17,9 +17,7 @@
class OgbgWorkload(BaseOgbgWorkload):
- def init_model_fn(
- self,
- rng: spec.RandomState) -> spec.ModelInitState:
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
rng, params_rng = jax.random.split(rng, 2)
self._model = models.GNN(
self._num_outputs,
@@ -53,7 +51,8 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ dropout_rate: float = models.DROPOUT_RATE
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
"""Get predicted logits from the network for input graphs."""
del update_batch_norm # No BN in the GNN model.
if model_state is not None:
diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py
index 38f76db80..1147eb34b 100644
--- a/algoperf/workloads/wmt/wmt_jax/models.py
+++ b/algoperf/workloads/wmt/wmt_jax/models.py
@@ -13,9 +13,9 @@
from algoperf.jax_utils import Dropout
-
DROPOUT_RATE = 0.1
+
@struct.dataclass
class TransformerConfig:
"""Global hyperparameters used to minimize obnoxious kwarg plumbing."""
@@ -171,7 +171,8 @@ def __call__(self, inputs, dropout_rate=DROPOUT_RATE):
)(
inputs)
x = x * y
- x = Dropout(rate=dropout_rate)(x, rate=dropout_rate, deterministic=cfg.deterministic)
+ x = Dropout(rate=dropout_rate)(
+ x, rate=dropout_rate, deterministic=cfg.deterministic)
output = nn.Dense(
actual_out_dim,
dtype=cfg.dtype,
@@ -223,7 +224,8 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE):
deterministic=cfg.deterministic,
)(cfg.attention_temp * x, x, mask=encoder_mask)
- x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic, rate=dropout_rate)
+ x = Dropout(rate=dropout_rate)(
+ x, deterministic=cfg.deterministic, rate=dropout_rate)
x = x + inputs
if not pre_ln:
x = nn.LayerNorm(dtype=cfg.dtype)(x)
@@ -285,7 +287,8 @@ def __call__(
decode=cfg.decode,
)(cfg.attention_temp * x, x, mask=decoder_mask)
- x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic, rate=dropout_rate)
+ x = Dropout(rate=dropout_rate)(
+ x, deterministic=cfg.deterministic, rate=dropout_rate)
x = x + targets
if not pre_ln:
x = nn.LayerNorm(dtype=cfg.dtype)(x)
@@ -304,7 +307,8 @@ def __call__(
deterministic=cfg.deterministic,
)(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask)
- y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic, rate=dropout_rate)
+ y = Dropout(rate=dropout_rate)(
+ y, deterministic=cfg.deterministic, rate=dropout_rate)
y = y + x
if not pre_ln:
y = nn.LayerNorm(dtype=cfg.dtype)(y)
@@ -361,7 +365,8 @@ def __call__(self,
x = AddPositionEmbs(
config=cfg, decode=False, name="posembed_input")(
x, inputs_positions=inputs_positions)
- x = Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic, rate=dropout_rate)
+ x = Dropout(rate=dropout_rate)(
+ x, deterministic=cfg.deterministic, rate=dropout_rate)
x = x.astype(cfg.dtype)
@@ -432,7 +437,8 @@ def __call__(
y = AddPositionEmbs(
config=cfg, decode=cfg.decode, name="posembed_output")(
y, inputs_positions=targets_positions)
- y = Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic, rate=dropout_rate)
+ y = Dropout(rate=dropout_rate)(
+ y, deterministic=cfg.deterministic, rate=dropout_rate)
y = y.astype(cfg.dtype)
diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py
index 24d4852b8..d402f9d95 100644
--- a/algoperf/workloads/wmt/wmt_jax/workload.py
+++ b/algoperf/workloads/wmt/wmt_jax/workload.py
@@ -206,9 +206,7 @@ def translate_and_calculate_bleu(self,
bleu_score = bleu.corpus_bleu(predictions, [references]).score
return bleu_score
- def init_model_fn(
- self,
- rng: spec.RandomState) -> spec.ModelInitState:
+ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
init_fake_batch_size = 2
input_shape = (init_fake_batch_size, 256)
target_shape = (init_fake_batch_size, 256)
@@ -250,7 +248,8 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
- dropout_rate: [float] = models.DROPOUT_RATE) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+      dropout_rate: float = models.DROPOUT_RATE
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
From 3ac97ae017defbed75aa8eedd4c1ab09b759e4c1 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 18 Jun 2025 21:20:02 +0000
Subject: [PATCH 069/123] fix formatting
---
.../test_model_equivalence.py | 125 ++++++-----
.../fastmri_pytorch/test_model_equivalence.py | 177 +++++++++-------
.../test_model_equivalence.py | 199 +++++++++---------
.../test_model_equivalence.py | 81 ++++---
.../test_model_equivalence.py | 145 +++++++------
.../ogbg_pytorch/test_model_equivalence.py | 123 ++++++-----
.../wmt_pytorch/test_model_equivalence.py | 160 +++++++-------
7 files changed, 540 insertions(+), 470 deletions(-)
diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
index f51f1f9aa..733052dd0 100644
--- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
@@ -31,63 +31,90 @@
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)
+
class ModelEquivalenceTest(parameterized.TestCase):
- @parameterized.named_parameters(
- dict(testcase_name='DLRMResNet, p=0.0', model='dlrm_resnet', dropout_rate=0.0),
- dict(testcase_name='DlrmSmall, p=0.0', model='dlrm_small', dropout_rate=0.0),
- dict(testcase_name='DLRMResNet, p=0.1', model='dlrm_resnet', dropout_rate=0.1),
- dict(testcase_name='DlrmSmall, p=0.1', model='dlrm_small', dropout_rate=0.1),
- dict(testcase_name='DLRMResNet, p=1.0', model='dlrm_resnet', dropout_rate=1.0),
- dict(testcase_name='DlrmSmall, p=1.0', model='dlrm_small', dropout_rate=1.0),
+ @parameterized.named_parameters(
+ dict(
+ testcase_name='DLRMResNet, p=0.0',
+ model='dlrm_resnet',
+ dropout_rate=0.0),
+ dict(
+ testcase_name='DlrmSmall, p=0.0',
+ model='dlrm_small',
+ dropout_rate=0.0),
+ dict(
+ testcase_name='DLRMResNet, p=0.1',
+ model='dlrm_resnet',
+ dropout_rate=0.1),
+ dict(
+ testcase_name='DlrmSmall, p=0.1',
+ model='dlrm_small',
+ dropout_rate=0.1),
+ dict(
+ testcase_name='DLRMResNet, p=1.0',
+ model='dlrm_resnet',
+ dropout_rate=1.0),
+ dict(
+ testcase_name='DlrmSmall, p=1.0',
+ model='dlrm_small',
+ dropout_rate=1.0),
+ )
+ def test_forward(self, model, dropout_rate):
+ OrigCls, CustCls = (
+ (OriginalDLRMResNet, CustomDLRMResNet)
+ if model == 'dlrm_resnet'
+ else (OriginalDlrmSmall, CustomDlrmSmall)
)
- def test_forward(self, model, dropout_rate):
- OrigCls, CustCls = (
- (OriginalDLRMResNet, CustomDLRMResNet)
- if model == 'dlrm_resnet'
- else (OriginalDlrmSmall, CustomDlrmSmall)
- )
- torch.manual_seed(SEED)
- orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate).to(DEVICE)
+ torch.manual_seed(SEED)
+ orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustCls(vocab_size=VOCAB).to(DEVICE)
+
+ x = torch.randn(BATCH, FEATURES, device=DEVICE)
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+ torch.manual_seed(SEED)
+ y1 = orig(x)
+ torch.manual_seed(SEED)
+ y2 = cust(x, dropout_rate)
+ assert_close(y1, y2, atol=0, rtol=0)
+ if mode == 'eval': # one extra test: omit dropout at eval
torch.manual_seed(SEED)
- cust = CustCls(vocab_size=VOCAB).to(DEVICE)
-
- x = torch.randn(BATCH, FEATURES, device=DEVICE)
-
- for mode in ('train', 'eval'):
- getattr(orig, mode)(); getattr(cust, mode)()
- torch.manual_seed(SEED); y1 = orig(x)
- torch.manual_seed(SEED); y2 = cust(x, dropout_rate)
- assert_close(y1, y2, atol=0, rtol=0)
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(SEED); y2 = cust(x)
- assert_close(y1, y2, atol=0, rtol=0)
-
- @parameterized.named_parameters(
- dict(testcase_name='DLRMResNet, default', model='dlrm_resnet'),
- dict(testcase_name='DlrmSmall, default', model='dlrm_small'),
+ y2 = cust(x)
+ assert_close(y1, y2, atol=0, rtol=0)
+
+ @parameterized.named_parameters(
+ dict(testcase_name='DLRMResNet, default', model='dlrm_resnet'),
+ dict(testcase_name='DlrmSmall, default', model='dlrm_small'),
+ )
+ def test_default_dropout(self, model):
+ """Test default dropout_rate."""
+ OrigCls, CustCls = (
+ (OriginalDLRMResNet, CustomDLRMResNet)
+ if model == 'dlrm_resnet'
+ else (OriginalDlrmSmall, CustomDlrmSmall)
)
- def test_default_dropout(self, model):
- """Test default dropout_rate."""
- OrigCls, CustCls = (
- (OriginalDLRMResNet, CustomDLRMResNet)
- if model == 'dlrm_resnet'
- else (OriginalDlrmSmall, CustomDlrmSmall)
- )
- torch.manual_seed(SEED)
- orig = OrigCls(vocab_size=VOCAB).to(DEVICE)
- torch.manual_seed(SEED)
- cust = CustCls(vocab_size=VOCAB).to(DEVICE)
+ torch.manual_seed(SEED)
+ orig = OrigCls(vocab_size=VOCAB).to(DEVICE)
+ torch.manual_seed(SEED)
+ cust = CustCls(vocab_size=VOCAB).to(DEVICE)
+
+ x = torch.randn(BATCH, FEATURES, device=DEVICE)
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+ torch.manual_seed(0)
+ y1 = orig(x)
+ torch.manual_seed(0)
+ y2 = cust(x)
+ assert_close(y1, y2, atol=0, rtol=0)
- x = torch.randn(BATCH, FEATURES, device=DEVICE)
- for mode in ('train', 'eval'):
- getattr(orig, mode)(); getattr(cust, mode)()
- torch.manual_seed(0); y1 = orig(x)
- torch.manual_seed(0); y2 = cust(x)
- assert_close(y1, y2, atol=0, rtol=0)
if __name__ == '__main__':
- absltest.main()
+ absltest.main()
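The forward-pass comparison that this and the following test files repeat boils down to one pattern: reseed, run the original model (rate fixed at construction), reseed again, run the custom model with the call-time rate, and require bitwise-identical outputs. A condensed sketch, where check_equivalence is a hypothetical helper and the two models are assumed to be built as in the test above:

    import torch
    from torch.testing import assert_close


    def check_equivalence(orig, cust, x, dropout_rate, seed):
      for mode in ('train', 'eval'):
        getattr(orig, mode)()
        getattr(cust, mode)()
        torch.manual_seed(seed)
        y1 = orig(x)
        torch.manual_seed(seed)
        y2 = cust(x, dropout_rate)
        # Same seed, same dropout mask: outputs must match exactly.
        assert_close(y1, y2, atol=0, rtol=0)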
diff --git a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
index 46f1a0f5a..1d318e8c6 100644
--- a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
@@ -27,87 +27,104 @@
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)
+
class FastMRIModeEquivalenceTest(parameterized.TestCase):
- def fwd_pass(self, orig, cust, dropout_rate):
- x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
- for mode in ('train', 'eval'):
- getattr(orig, mode)(); getattr(cust, mode)()
- torch.manual_seed(0); y1 = orig(x)
- torch.manual_seed(0); y2 = cust(x, dropout_rate)
- assert_close(y1, y2, atol=0, rtol=0)
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(0); y2 = cust(x)
- assert_close(y1, y2, atol=0, rtol=0)
-
- @parameterized.named_parameters(
- dict(testcase_name='p=0.0', dropout_rate=0.0),
- dict(testcase_name='p=0.1', dropout_rate=0.1),
- dict(testcase_name='p=0.7', dropout_rate=0.7),
- dict(testcase_name='p=1.0', dropout_rate=1.0),
- )
- def test_dropout_values(self, dropout_rate):
- """Test different values of dropout_rate."""
-
- torch.manual_seed(SEED)
- orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate).to(DEVICE)
-
- torch.manual_seed(SEED)
- cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
-
- cust.load_state_dict(orig.state_dict()) # sync weights
- if TORCH_COMPILE:
- orig = torch.compile(orig); cust = torch.compile(cust)
-
- self.fwd_pass(orig, cust, dropout_rate)
-
-
- @parameterized.named_parameters(
- dict(testcase_name='default', use_tanh=False, use_layer_norm=False),
- dict(testcase_name='tanh', use_tanh=True, use_layer_norm=False),
- dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True),
- dict(testcase_name='both', use_tanh=True, use_layer_norm=True),
- )
- def test_arch_configs(self, use_tanh, use_layer_norm):
- """Test different architecture configurations, fixed dropout_rate."""
- dropout_rate = 0.1
-
- torch.manual_seed(SEED)
- orig = OriginalUNet(
- IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate,
- use_tanh=use_tanh, use_layer_norm=use_layer_norm
- ).to(DEVICE)
-
- torch.manual_seed(SEED)
- cust = CustomUNet(
- IN_CHANS, OUT_CHANS, C, LAYERS,
- use_tanh=use_tanh, use_layer_norm=use_layer_norm
- ).to(DEVICE)
-
- cust.load_state_dict(orig.state_dict()) # sync weights
- if TORCH_COMPILE:
- orig = torch.compile(orig); cust = torch.compile(cust)
-
- self.fwd_pass(orig, cust, dropout_rate)
-
- @parameterized.named_parameters(
- dict(testcase_name=''),
- )
- def test_default_dropout(self):
- """Test default dropout_rate."""
-
- torch.manual_seed(SEED)
- orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
- torch.manual_seed(SEED)
- cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
- cust.load_state_dict(orig.state_dict()) # sync weights
-
- x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
- for mode in ('train', 'eval'):
- getattr(orig, mode)(); getattr(cust, mode)()
- torch.manual_seed(0); y1 = orig(x)
- torch.manual_seed(0); y2 = cust(x)
- assert_close(y1, y2, atol=0, rtol=0)
+ def fwd_pass(self, orig, cust, dropout_rate):
+ x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+ torch.manual_seed(0)
+ y1 = orig(x)
+ torch.manual_seed(0)
+ y2 = cust(x, dropout_rate)
+ assert_close(y1, y2, atol=0, rtol=0)
+ if mode == 'eval': # one extra test: omit dropout at eval
+ torch.manual_seed(0)
+ y2 = cust(x)
+ assert_close(y1, y2, atol=0, rtol=0)
+
+ @parameterized.named_parameters(
+ dict(testcase_name='p=0.0', dropout_rate=0.0),
+ dict(testcase_name='p=0.1', dropout_rate=0.1),
+ dict(testcase_name='p=0.7', dropout_rate=0.7),
+ dict(testcase_name='p=1.0', dropout_rate=1.0),
+ )
+ def test_dropout_values(self, dropout_rate):
+ """Test different values of dropout_rate."""
+
+ torch.manual_seed(SEED)
+ orig = OriginalUNet(
+ IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
+
+ cust.load_state_dict(orig.state_dict()) # sync weights
+ if TORCH_COMPILE:
+ orig = torch.compile(orig)
+ cust = torch.compile(cust)
+
+ self.fwd_pass(orig, cust, dropout_rate)
+
+ @parameterized.named_parameters(
+ dict(testcase_name='default', use_tanh=False, use_layer_norm=False),
+ dict(testcase_name='tanh', use_tanh=True, use_layer_norm=False),
+ dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True),
+ dict(testcase_name='both', use_tanh=True, use_layer_norm=True),
+ )
+ def test_arch_configs(self, use_tanh, use_layer_norm):
+ """Test different architecture configurations, fixed dropout_rate."""
+ dropout_rate = 0.1
+
+ torch.manual_seed(SEED)
+ orig = OriginalUNet(
+ IN_CHANS,
+ OUT_CHANS,
+ C,
+ LAYERS,
+ dropout_rate=dropout_rate,
+ use_tanh=use_tanh,
+ use_layer_norm=use_layer_norm).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustomUNet(
+ IN_CHANS,
+ OUT_CHANS,
+ C,
+ LAYERS,
+ use_tanh=use_tanh,
+ use_layer_norm=use_layer_norm).to(DEVICE)
+
+ cust.load_state_dict(orig.state_dict()) # sync weights
+ if TORCH_COMPILE:
+ orig = torch.compile(orig)
+ cust = torch.compile(cust)
+
+ self.fwd_pass(orig, cust, dropout_rate)
+
+ @parameterized.named_parameters(
+ dict(testcase_name=''),)
+ def test_default_dropout(self):
+ """Test default dropout_rate."""
+
+ torch.manual_seed(SEED)
+ orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
+ torch.manual_seed(SEED)
+ cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
+ cust.load_state_dict(orig.state_dict()) # sync weights
+
+ x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+ torch.manual_seed(0)
+ y1 = orig(x)
+ torch.manual_seed(0)
+ y2 = cust(x)
+ assert_close(y1, y2, atol=0, rtol=0)
+
if __name__ == '__main__':
- absltest.main()
+ absltest.main()
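The weight-sync step used throughout these tests makes the comparison purely about dropout handling, since both model variants share parameter names; the optional torch.compile wrapping is applied to both models so compiled and eager paths are compared like for like. A small sketch of that setup step (sync_and_maybe_compile is a hypothetical helper; the copy direction varies between test files):

    import torch


    def sync_and_maybe_compile(orig, cust, use_compile):
      # Copy weights so any output difference comes from dropout handling only.
      cust.load_state_dict(orig.state_dict())
      if use_compile:
        orig = torch.compile(orig)
        cust = torch.compile(cust)
      return orig, cust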
diff --git a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
index e6d58f5f7..f51eaec7e 100644
--- a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
@@ -18,7 +18,7 @@
ViT as CustomVit
# Model / test hyper-params
-BATCH, C, H, W = 4, 3, 224, 224 # input shape (N,C,H,W)
+BATCH, C, H, W = 4, 3, 224, 224 # input shape (N,C,H,W)
WIDTH, DEPTH, HEADS = 256, 4, 8
DROPOUT_RATE = None
DEVICE = 'cuda'
@@ -29,103 +29,110 @@
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)
+
class ImageNetVitModeEquivalenceTest(parameterized.TestCase):
- def fwd_pass(self, orig, cust, dropout_rate):
- x = torch.randn(BATCH, C, H, W, device=DEVICE)
- for mode in ('train', 'eval'):
- getattr(orig, mode)(); getattr(cust, mode)()
- torch.manual_seed(0); y1 = orig(x)
- torch.manual_seed(0); y2 = cust(x, dropout_rate)
- assert_close(y1, y2, atol=0, rtol=0)
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(0); y2 = cust(x, dropout_rate)
- assert_close(y1, y2, atol=0, rtol=0)
-
- @parameterized.named_parameters(
- dict(testcase_name='p=0.0', dropout_rate=0.0),
- dict(testcase_name='p=0.1', dropout_rate=0.1),
- dict(testcase_name='p=0.6', dropout_rate=0.6),
- dict(testcase_name='p=1.0', dropout_rate=1.0),
- )
- def test_dropout_values(self, dropout_rate):
- """Test different dropout_rates."""
-
- torch.manual_seed(SEED)
- orig = OriginalVit(
- width=WIDTH,
- depth=DEPTH,
- num_heads=HEADS,
- dropout_rate=dropout_rate,
- ).to(DEVICE)
-
- torch.manual_seed(SEED)
- cust = CustomVit(
- width=WIDTH,
- depth=DEPTH,
- num_heads=HEADS,
- ).to(DEVICE)
-
- cust.load_state_dict(orig.state_dict()) # sync weights
- self.fwd_pass(orig, cust, dropout_rate)
-
-
- @parameterized.named_parameters([
- dict(
- testcase_name=f"GLU={use_glu}_LN={use_post_ln}_MAP={use_map}",
- use_glu=use_glu,
- use_post_ln=use_post_ln,
- use_map=use_map,
- )
- for use_glu, use_post_ln, use_map in itertools.product([False, True], repeat=3)
- ])
- def test_arch(self, use_glu, use_post_ln, use_map):
- """Test different architecture configurations, fixed dropout_rate."""
- dropout_rate = 0.1
-
- torch.manual_seed(SEED)
- orig = OriginalVit(
- width=WIDTH,
- depth=DEPTH,
- num_heads=HEADS,
- use_glu=use_glu,
- use_post_layer_norm=use_post_ln,
- use_map=use_map,
- dropout_rate=dropout_rate,
- ).to(DEVICE)
-
- torch.manual_seed(SEED)
- cust = CustomVit(
- width=WIDTH,
- depth=DEPTH,
- num_heads=HEADS,
- use_glu=use_glu,
- use_post_layer_norm=use_post_ln,
- use_map=use_map,
- ).to(DEVICE)
-
- cust.load_state_dict(orig.state_dict()) # sync weights
- self.fwd_pass(orig, cust, dropout_rate)
-
- @parameterized.named_parameters(
- dict(testcase_name=''),
- )
- def test_default_dropout(self):
- """Test default dropout_rate."""
-
- torch.manual_seed(SEED)
- orig = OriginalVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE)
- torch.manual_seed(SEED)
- cust = CustomVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE)
- cust.load_state_dict(orig.state_dict()) # sync weights
-
- x = torch.randn(BATCH, C, H, W, device=DEVICE)
- for mode in ('train', 'eval'):
- getattr(orig, mode)(); getattr(cust, mode)()
- torch.manual_seed(0); y1 = orig(x)
- torch.manual_seed(0); y2 = cust(x)
- assert_close(y1, y2, atol=0, rtol=0)
+ def fwd_pass(self, orig, cust, dropout_rate):
+ x = torch.randn(BATCH, C, H, W, device=DEVICE)
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+ torch.manual_seed(0)
+ y1 = orig(x)
+ torch.manual_seed(0)
+ y2 = cust(x, dropout_rate)
+ assert_close(y1, y2, atol=0, rtol=0)
+ if mode == 'eval': # one extra test: omit dropout at eval
+ torch.manual_seed(0)
+ y2 = cust(x, dropout_rate)
+ assert_close(y1, y2, atol=0, rtol=0)
+
+ @parameterized.named_parameters(
+ dict(testcase_name='p=0.0', dropout_rate=0.0),
+ dict(testcase_name='p=0.1', dropout_rate=0.1),
+ dict(testcase_name='p=0.6', dropout_rate=0.6),
+ dict(testcase_name='p=1.0', dropout_rate=1.0),
+ )
+ def test_dropout_values(self, dropout_rate):
+ """Test different dropout_rates."""
+
+ torch.manual_seed(SEED)
+ orig = OriginalVit(
+ width=WIDTH,
+ depth=DEPTH,
+ num_heads=HEADS,
+ dropout_rate=dropout_rate,
+ ).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustomVit(
+ width=WIDTH,
+ depth=DEPTH,
+ num_heads=HEADS,
+ ).to(DEVICE)
+
+ cust.load_state_dict(orig.state_dict()) # sync weights
+ self.fwd_pass(orig, cust, dropout_rate)
+
+ @parameterized.named_parameters([
+ dict(
+ testcase_name=f"GLU={use_glu}_LN={use_post_ln}_MAP={use_map}",
+ use_glu=use_glu,
+ use_post_ln=use_post_ln,
+ use_map=use_map,
+ ) for use_glu,
+ use_post_ln,
+ use_map in itertools.product([False, True], repeat=3)
+ ])
+ def test_arch(self, use_glu, use_post_ln, use_map):
+ """Test different architecture configurations, fixed dropout_rate."""
+ dropout_rate = 0.1
+
+ torch.manual_seed(SEED)
+ orig = OriginalVit(
+ width=WIDTH,
+ depth=DEPTH,
+ num_heads=HEADS,
+ use_glu=use_glu,
+ use_post_layer_norm=use_post_ln,
+ use_map=use_map,
+ dropout_rate=dropout_rate,
+ ).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustomVit(
+ width=WIDTH,
+ depth=DEPTH,
+ num_heads=HEADS,
+ use_glu=use_glu,
+ use_post_layer_norm=use_post_ln,
+ use_map=use_map,
+ ).to(DEVICE)
+
+ cust.load_state_dict(orig.state_dict()) # sync weights
+ self.fwd_pass(orig, cust, dropout_rate)
+
+ @parameterized.named_parameters(
+ dict(testcase_name=''),)
+ def test_default_dropout(self):
+ """Test default dropout_rate."""
+
+ torch.manual_seed(SEED)
+ orig = OriginalVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE)
+ torch.manual_seed(SEED)
+ cust = CustomVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE)
+ cust.load_state_dict(orig.state_dict()) # sync weights
+
+ x = torch.randn(BATCH, C, H, W, device=DEVICE)
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+ torch.manual_seed(0)
+ y1 = orig(x)
+ torch.manual_seed(0)
+ y2 = cust(x)
+ assert_close(y1, y2, atol=0, rtol=0)
if __name__ == '__main__':
- absltest.main()
+ absltest.main()
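The architecture sweep above is generated with a cartesian product over the three flags, yielding eight named cases; in isolation the expansion is simply:

    import itertools

    cases = [
        dict(
            testcase_name=f"GLU={use_glu}_LN={use_post_ln}_MAP={use_map}",
            use_glu=use_glu,
            use_post_ln=use_post_ln,
            use_map=use_map,
        ) for use_glu, use_post_ln, use_map in itertools.product([False, True],
                                                                 repeat=3)
    ]
    assert len(cases) == 8  # 2**3 flag combinations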
diff --git a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
index d9511fcc4..02f3a3d84 100644
--- a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
@@ -26,7 +26,7 @@
B, T = 32, 36_000
DEVICE = 'cuda'
-os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(mode=True)
@@ -35,53 +35,50 @@
class ConformerEquivalenceTest(parameterized.TestCase):
- @parameterized.named_parameters(
- dict(testcase_name='p=0.0', dropout_rate=0.0),
- dict(testcase_name='p=0.2', dropout_rate=0.2),
- dict(testcase_name='p=0.7', dropout_rate=0.7),
- dict(testcase_name='p=1.0', dropout_rate=1.0),
- )
- def test_forward(self, dropout_rate):
-
- torch.manual_seed(SEED)
- orig = OriginalModel(
- OriginalConfig(
+ @parameterized.named_parameters(
+ dict(testcase_name='p=0.0', dropout_rate=0.0),
+ dict(testcase_name='p=0.2', dropout_rate=0.2),
+ dict(testcase_name='p=0.7', dropout_rate=0.7),
+ dict(testcase_name='p=1.0', dropout_rate=1.0),
+ )
+ def test_forward(self, dropout_rate):
+
+ torch.manual_seed(SEED)
+ orig = OriginalModel(
+ OriginalConfig(
num_encoder_layers=N_LAYERS,
attention_residual_dropout_rate=dropout_rate,
conv_residual_dropout_rate=dropout_rate,
feed_forward_residual_dropout_rate=dropout_rate,
input_dropout_rate=dropout_rate,
)).to(DEVICE)
-
+
+ torch.manual_seed(SEED)
+ cust = CustomModel(CustomConfig(num_encoder_layers=N_LAYERS)).to(DEVICE)
+
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ x = torch.randn(B, T, device=DEVICE)
+ paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
+
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
+ torch.manual_seed(SEED)
+ y1, p1 = orig(x, paddings)
+ torch.manual_seed(SEED)
+ y2, p2 = cust(x, paddings, dropout_rate=dropout_rate)
+
+ assert_close(y1, y2, atol=0, rtol=0)
+ assert_close(p1, p2, atol=0, rtol=0)
+
+ if mode == 'eval': # one extra test: omit dropout at eval
torch.manual_seed(SEED)
- cust = CustomModel(
- CustomConfig(
- num_encoder_layers=N_LAYERS
- )
- ).to(DEVICE)
-
- orig.load_state_dict(cust.state_dict()) # sync weights
-
- x = torch.randn(B, T, device=DEVICE)
- paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
-
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
-
- torch.manual_seed(SEED)
- y1, p1 = orig(x, paddings)
- torch.manual_seed(SEED)
- y2, p2 = cust(x, paddings, dropout_rate=dropout_rate)
-
- assert_close(y1, y2, atol=0, rtol=0)
- assert_close(p1, p2, atol=0, rtol=0)
-
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(SEED)
- y2, p2 = cust(x, paddings)
- assert_close(y1, y2, atol=0, rtol=0)
- assert_close(p1, p2, atol=0, rtol=0)
+ y2, p2 = cust(x, paddings)
+ assert_close(y1, y2, atol=0, rtol=0)
+ assert_close(p1, p2, atol=0, rtol=0)
+
if __name__ == '__main__':
- absltest.main()
+ absltest.main()
diff --git a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
index 610e6b18e..7d6a94592 100644
--- a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
@@ -28,7 +28,7 @@
DEVICE = 'cuda'
TORCH_COMPILE = False
-os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(mode=True)
@@ -37,83 +37,88 @@
class DeepSpeechEquivalenceTest(parameterized.TestCase):
- @parameterized.named_parameters(
- dict(testcase_name='p=0.0', dropout_rate=0.0),
- dict(testcase_name='p=0.2', dropout_rate=0.2),
- dict(testcase_name='p=0.7', dropout_rate=0.7),
- dict(testcase_name='p=1.0', dropout_rate=1.0),
- )
- def test_forward(self, dropout_rate):
- """Test different dropout_rate values."""
-
- torch.manual_seed(SEED)
- orig = OriginalModel(
- OriginalConfig(
+ @parameterized.named_parameters(
+ dict(testcase_name='p=0.0', dropout_rate=0.0),
+ dict(testcase_name='p=0.2', dropout_rate=0.2),
+ dict(testcase_name='p=0.7', dropout_rate=0.7),
+ dict(testcase_name='p=1.0', dropout_rate=1.0),
+ )
+ def test_forward(self, dropout_rate):
+ """Test different dropout_rate values."""
+
+ torch.manual_seed(SEED)
+ orig = OriginalModel(
+ OriginalConfig(
num_lstm_layers=2,
num_ffn_layers=2,
input_dropout_rate=dropout_rate,
feed_forward_dropout_rate=dropout_rate,
)).to(DEVICE)
-
- torch.manual_seed(SEED)
- cust = CustomModel(CustomConfig(
- num_lstm_layers=2,
- num_ffn_layers=2,
- )).to(DEVICE)
-
- orig.load_state_dict(cust.state_dict()) # sync weights
-
- if TORCH_COMPILE:
- orig = torch.compile(orig); cust = torch.compile(cust)
-
- x = torch.randn(B, T, device=DEVICE)
- paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
-
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
-
- torch.manual_seed(SEED)
- y1, p1 = orig(x, paddings)
-
- torch.manual_seed(SEED)
- y2, p2 = cust(x, paddings, dropout_rate=dropout_rate)
-
- assert_close(y1, y2, atol=0, rtol=0)
- assert_close(p1, p2, atol=0, rtol=0)
-
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(SEED)
- y2, p2 = cust(x, paddings)
- assert_close(y1, y2, atol=0, rtol=0)
- assert_close(p1, p2, atol=0, rtol=0)
-
-
- @parameterized.named_parameters(
- dict(testcase_name=''),
- )
- def test_default_dropout(self):
- """Test default dropout_rate."""
- torch.manual_seed(SEED)
- orig = OriginalModel(OriginalConfig( num_lstm_layers=2, num_ffn_layers=2)).to(DEVICE)
- torch.manual_seed(SEED)
- cust = CustomModel(CustomConfig( num_lstm_layers=2, num_ffn_layers=2)).to(DEVICE)
- orig.load_state_dict(cust.state_dict()) # sync weights
+ torch.manual_seed(SEED)
+ cust = CustomModel(CustomConfig(
+ num_lstm_layers=2,
+ num_ffn_layers=2,
+ )).to(DEVICE)
- if TORCH_COMPILE:
- orig = torch.compile(orig); cust = torch.compile(cust)
+ orig.load_state_dict(cust.state_dict()) # sync weights
- x = torch.randn(B, T, device=DEVICE)
- paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
+ if TORCH_COMPILE:
+ orig = torch.compile(orig)
+ cust = torch.compile(cust)
+
+ x = torch.randn(B, T, device=DEVICE)
+ paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
+
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
+ torch.manual_seed(SEED)
+ y1, p1 = orig(x, paddings)
+
+ torch.manual_seed(SEED)
+ y2, p2 = cust(x, paddings, dropout_rate=dropout_rate)
+
+ assert_close(y1, y2, atol=0, rtol=0)
+ assert_close(p1, p2, atol=0, rtol=0)
+
+ if mode == 'eval': # one extra test: omit dropout at eval
+ torch.manual_seed(SEED)
+ y2, p2 = cust(x, paddings)
+ assert_close(y1, y2, atol=0, rtol=0)
+ assert_close(p1, p2, atol=0, rtol=0)
+
+ @parameterized.named_parameters(
+ dict(testcase_name=''),)
+ def test_default_dropout(self):
+ """Test default dropout_rate."""
+
+ torch.manual_seed(SEED)
+ orig = OriginalModel(OriginalConfig(num_lstm_layers=2,
+ num_ffn_layers=2)).to(DEVICE)
+ torch.manual_seed(SEED)
+ cust = CustomModel(CustomConfig(num_lstm_layers=2,
+ num_ffn_layers=2)).to(DEVICE)
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ if TORCH_COMPILE:
+ orig = torch.compile(orig)
+ cust = torch.compile(cust)
+
+ x = torch.randn(B, T, device=DEVICE)
+ paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
+
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+ torch.manual_seed(SEED)
+ y1, p1 = orig(x, paddings)
+ torch.manual_seed(SEED)
+ y2, p2 = cust(x, paddings)
+ assert_close(y1, y2, atol=0, rtol=0)
+ assert_close(p1, p2, atol=0, rtol=0)
- for mode in ('train', 'eval'):
- getattr(orig, mode)(); getattr(cust, mode)()
- torch.manual_seed(SEED); y1, p1 = orig(x, paddings)
- torch.manual_seed(SEED); y2, p2 = cust(x, paddings)
- assert_close(y1, y2, atol=0, rtol=0)
- assert_close(p1, p2, atol=0, rtol=0)
-
if __name__ == '__main__':
- absltest.main()
+ absltest.main()
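Both speech test files depend on the same determinism setup so that reseeded forward passes on GPU are exactly reproducible; gathered in one place, the configuration set at import time in the files above is:

    import os

    import torch

    # cuBLAS needs this workspace setting for deterministic GEMM kernels.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)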
diff --git a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
index f5c7e992b..3b3feb680 100644
--- a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
@@ -18,7 +18,7 @@
from algoperf.workloads.ogbg.ogbg_pytorch.models_dropout import \
GNN as CustomModel
-B, N, E = 8, 20, 40 # graphs, nodes/graph, edges/graph
+B, N, E = 8, 20, 40 # graphs, nodes/graph, edges/graph
NODE_FDIM, EDGE_FDIM = 9, 3 # expected feature dims
DEVICE = 'cuda'
@@ -30,80 +30,89 @@
def _rand_graph():
- total_nodes, total_edges = B * N, B * E
- nodes = torch.randn(total_nodes, NODE_FDIM, device=DEVICE)
- edges = torch.randn(total_edges, EDGE_FDIM, device=DEVICE)
- senders, receivers = [], []
- for i in range(B):
- offset = i * N
- s = torch.randint(N, (E,), device=DEVICE) + offset
- r = torch.randint(N, (E,), device=DEVICE) + offset
- senders.append(s), receivers.append(r)
- senders = torch.cat(senders); receivers = torch.cat(receivers)
- n_node = torch.full((B,), N, device=DEVICE, dtype=torch.int32)
- n_edge = torch.full((B,), E, device=DEVICE, dtype=torch.int32)
- return GraphsTuple(nodes, edges, receivers, senders, None, n_node, n_edge)
+ total_nodes, total_edges = B * N, B * E
+ nodes = torch.randn(total_nodes, NODE_FDIM, device=DEVICE)
+ edges = torch.randn(total_edges, EDGE_FDIM, device=DEVICE)
+ senders, receivers = [], []
+ for i in range(B):
+ offset = i * N
+ s = torch.randint(N, (E,), device=DEVICE) + offset
+ r = torch.randint(N, (E,), device=DEVICE) + offset
+ senders.append(s), receivers.append(r)
+ senders = torch.cat(senders)
+ receivers = torch.cat(receivers)
+ n_node = torch.full((B,), N, device=DEVICE, dtype=torch.int32)
+ n_edge = torch.full((B,), E, device=DEVICE, dtype=torch.int32)
+ return GraphsTuple(nodes, edges, receivers, senders, None, n_node, n_edge)
class GNNEquivalenceTest(parameterized.TestCase):
- @parameterized.named_parameters(
- dict(testcase_name='0.0', dropout_rate=0.0),
- dict(testcase_name='0.2', dropout_rate=0.2),
- dict(testcase_name='0.7', dropout_rate=0.7),
- dict(testcase_name='1.0', dropout_rate=1.0),
- )
- def test_forward(self, dropout_rate):
- """Test different dropout_rates."""
+ @parameterized.named_parameters(
+ dict(testcase_name='0.0', dropout_rate=0.0),
+ dict(testcase_name='0.2', dropout_rate=0.2),
+ dict(testcase_name='0.7', dropout_rate=0.7),
+ dict(testcase_name='1.0', dropout_rate=1.0),
+ )
+ def test_forward(self, dropout_rate):
+ """Test different dropout_rates."""
- orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE)
- cust = CustomModel().to(DEVICE)
- orig.load_state_dict(cust.state_dict()) # sync weights
+ orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE)
+ cust = CustomModel().to(DEVICE)
+ orig.load_state_dict(cust.state_dict()) # sync weights
- graph = _rand_graph()
+ graph = _rand_graph()
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
- torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
- y1 = orig(graph)
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y1 = orig(graph)
- torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
- y2 = cust(graph, dropout_rate=dropout_rate)
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y2 = cust(graph, dropout_rate=dropout_rate)
- assert_close(y1, y2, atol=0, rtol=0)
+ assert_close(y1, y2, atol=0, rtol=0)
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
- y2 = cust(graph)
- assert_close(y1, y2, atol=0, rtol=0)
+ if mode == 'eval': # one extra test: omit dropout at eval
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y2 = cust(graph)
+ assert_close(y1, y2, atol=0, rtol=0)
+ @parameterized.named_parameters(
+ dict(testcase_name=''),)
+ def test_default_dropout(self):
+ """Test default dropout_rate."""
- @parameterized.named_parameters(
- dict(testcase_name=''),
- )
- def test_default_dropout(self):
- """Test default dropout_rate."""
-
- orig = OriginalModel().to(DEVICE)
- cust = CustomModel().to(DEVICE)
- orig.load_state_dict(cust.state_dict()) # sync weights
+ orig = OriginalModel().to(DEVICE)
+ cust = CustomModel().to(DEVICE)
+ orig.load_state_dict(cust.state_dict()) # sync weights
- graph = _rand_graph()
+ graph = _rand_graph()
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
- torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
- y1 = orig(graph)
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y1 = orig(graph)
- torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
- y2 = cust(graph)
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y2 = cust(graph)
- assert_close(y1, y2, atol=0, rtol=0)
+ assert_close(y1, y2, atol=0, rtol=0)
if __name__ == '__main__':
- absltest.main()
+ absltest.main()
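The GNN test reseeds the PyTorch, Python, and NumPy RNGs before every forward pass; a tiny helper capturing that step (seed_everything is hypothetical):

    import random

    import numpy as np
    import torch


    def seed_everything(seed):
      torch.manual_seed(seed)
      random.seed(seed)
      np.random.seed(seed)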
diff --git a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
index 918043cfd..03f289a68 100644
--- a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
+++ b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
@@ -29,85 +29,93 @@
def _rand_tokens(bs, seqlen):
- return torch.randint(1, NTOK, (bs, seqlen), device=DEVICE)
+ return torch.randint(1, NTOK, (bs, seqlen), device=DEVICE)
class TransformerEquivalenceTest(parameterized.TestCase):
- @parameterized.named_parameters(
- # NOTE: removed dropout=1.0 since it will generate nan in scaled_dot_product_attention
- dict(testcase_name="0.0", dropout_rate=0.0, compile=False),
- dict(testcase_name="0.2", dropout_rate=0.2, compile=False),
- dict(testcase_name="0.7", dropout_rate=0.7, compile=False),
- dict(testcase_name="p=0.0_compile", dropout_rate=0.0, compile=True),
- dict(testcase_name="p=0.2_compile", dropout_rate=0.2, compile=True),
- dict(testcase_name="p=0.7_compile", dropout_rate=0.7, compile=True),
- )
- def test_dropout_value(self, dropout_rate, compile):
-
- orig = OriginalModel(
- dropout_rate=dropout_rate,
- attention_dropout_rate=dropout_rate
- ).to(DEVICE)
- cust = CustomModel().to(DEVICE)
-
- orig.load_state_dict(cust.state_dict()) # sync weights
-
- if compile:
- orig = torch.compile(orig)
- cust = torch.compile(cust)
-
- src = _rand_tokens(B, SRC_LEN)
- tgt = _rand_tokens(B, TGT_LEN)
-
- for mode in ("train", "eval"):
- getattr(orig, mode)()
- getattr(cust, mode)()
-
- torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
- y1 = orig(src, tgt)
-
- torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
- y2 = cust(src, tgt, dropout_rate=dropout_rate)
-
- assert_close(y1, y2, atol=0, rtol=0)
-
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
- y2 = cust(src, tgt)
- assert_close(y1, y2, atol=0, rtol=0)
-
-
- @parameterized.named_parameters(
- dict(testcase_name="default", compile=False),
- dict(testcase_name="default_compile", compile=True),
- )
- def test_default(self, compile):
-
- orig = OriginalModel().to(DEVICE)
- cust = CustomModel().to(DEVICE)
-
- orig.load_state_dict(cust.state_dict()) # sync weights
-
- if compile:
- orig = torch.compile(orig)
- cust = torch.compile(cust)
-
- src = _rand_tokens(B, SRC_LEN)
- tgt = _rand_tokens(B, TGT_LEN)
-
- for mode in ("train", "eval"):
- getattr(orig, mode)()
- getattr(cust, mode)()
-
- torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
- y1 = orig(src, tgt)
-
- torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
- y2 = cust(src, tgt)
-
- assert_close(y1, y2, atol=0, rtol=0)
-
+ @parameterized.named_parameters(
+      # NOTE: removed dropout=1.0 since it generates NaNs in scaled_dot_product_attention
+ dict(testcase_name="0.0", dropout_rate=0.0, compile=False),
+ dict(testcase_name="0.2", dropout_rate=0.2, compile=False),
+ dict(testcase_name="0.7", dropout_rate=0.7, compile=False),
+ dict(testcase_name="p=0.0_compile", dropout_rate=0.0, compile=True),
+ dict(testcase_name="p=0.2_compile", dropout_rate=0.2, compile=True),
+ dict(testcase_name="p=0.7_compile", dropout_rate=0.7, compile=True),
+ )
+ def test_dropout_value(self, dropout_rate, compile):
+
+ orig = OriginalModel(
+ dropout_rate=dropout_rate,
+ attention_dropout_rate=dropout_rate).to(DEVICE)
+ cust = CustomModel().to(DEVICE)
+
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ if compile:
+ orig = torch.compile(orig)
+ cust = torch.compile(cust)
+
+ src = _rand_tokens(B, SRC_LEN)
+ tgt = _rand_tokens(B, TGT_LEN)
+
+ for mode in ("train", "eval"):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y1 = orig(src, tgt)
+
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y2 = cust(src, tgt, dropout_rate=dropout_rate)
+
+ assert_close(y1, y2, atol=0, rtol=0)
+
+ if mode == 'eval': # one extra test: omit dropout at eval
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y2 = cust(src, tgt)
+ assert_close(y1, y2, atol=0, rtol=0)
+
+ @parameterized.named_parameters(
+ dict(testcase_name="default", compile=False),
+ dict(testcase_name="default_compile", compile=True),
+ )
+ def test_default(self, compile):
+
+ orig = OriginalModel().to(DEVICE)
+ cust = CustomModel().to(DEVICE)
+
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ if compile:
+ orig = torch.compile(orig)
+ cust = torch.compile(cust)
+
+ src = _rand_tokens(B, SRC_LEN)
+ tgt = _rand_tokens(B, TGT_LEN)
+
+ for mode in ("train", "eval"):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y1 = orig(src, tgt)
+
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y2 = cust(src, tgt)
+
+ assert_close(y1, y2, atol=0, rtol=0)
+
if __name__ == "__main__":
- absltest.main()
+ absltest.main()
From badf12453a56b078c3b156c0b850ec5e8158bb81 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 18 Jun 2025 21:23:03 +0000
Subject: [PATCH 070/123] fix test
---
tests/reference_algorithm_tests.py | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py
index c4ca514a8..58a4a5ddc 100644
--- a/tests/reference_algorithm_tests.py
+++ b/tests/reference_algorithm_tests.py
@@ -184,12 +184,11 @@ def __init__(self):
if 'librispeech' in workload_name:
self.tokenizer = _FakeTokenizer()
- def init_model_fn(self, rng, dropout_rate=None, aux_dropout_rate=None):
+ def init_model_fn(self, rng):
# pylint: disable=line-too-long
if not (FLAGS.identical and
os.path.exists(f'tests/modeldiffs/{workload_name}/compare.py')):
- return super().init_model_fn(
- rng, dropout_rate=dropout_rate, aux_dropout_rate=aux_dropout_rate)
+ return super().init_model_fn(rng)
if framework == 'jax':
compare_module = importlib.import_module(
f'tests.modeldiffs.{workload_name}.compare')
@@ -201,7 +200,7 @@ def init_model_fn(self, rng, dropout_rate=None, aux_dropout_rate=None):
return (FrozenDict(**jax_utils.replicate(jax_params)),
FrozenDict(**jax_utils.replicate(model_state))
if model_state is not None else model_state)
- return super().init_model_fn([0], dropout_rate=0.0, aux_dropout_rate=0.0)
+ return super().init_model_fn([0])
@property
def num_eval_train_examples(self):
From eff3ea19572d0c3f3497bcf997948d7b3cbe7c69 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 18 Jun 2025 21:28:18 +0000
Subject: [PATCH 071/123] fix lint errors
---
algoperf/workloads/fastmri/fastmri_jax/models.py | 1 -
algoperf/workloads/imagenet_vit/imagenet_jax/workload.py | 2 +-
.../workloads/librispeech_deepspeech/librispeech_jax/models.py | 3 ++-
algoperf/workloads/ogbg/ogbg_jax/models.py | 2 +-
algoperf/workloads/ogbg/ogbg_jax/workload.py | 2 +-
5 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py
index 70c7fc4a5..a5fe060b9 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/models.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/models.py
@@ -13,7 +13,6 @@
github.com/facebookresearch/fastMRI/tree/main/fastmri/data
"""
import functools
-from typing import Optional
import flax.linen as nn
import jax
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
index d0fb4fd72..ab9df0f62 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -1,6 +1,6 @@
"""ImageNet workload implemented in Jax."""
-from typing import Dict, Optional, Tuple
+from typing import Dict, Tuple
from flax import jax_utils
from flax import linen as nn
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
index 1bd998027..fab0b3259 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
@@ -489,7 +489,8 @@ def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE):
# Subsample input by a factor of 4 by performing strided convolutions.
outputs, output_paddings = Subsample(
- config=config)(outputs, output_paddings, train, dropout_rate=dropout_rate)
+ config=config)(outputs, output_paddings, train,
+ dropout_rate=dropout_rate)
# Run the lstm layers.
for _ in range(config.num_lstm_layers):
diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py
index 06eef6187..8524bb60e 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/models.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/models.py
@@ -1,6 +1,6 @@
# Forked from the init2winit implementation here
# https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py.
-from typing import Optional, Tuple
+from typing import Tuple
from flax import linen as nn
import jax.numpy as jnp
diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py
index 04a9bce2e..0535aea83 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/workload.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py
@@ -1,6 +1,6 @@
"""OGBG workload implemented in Jax."""
import functools
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Tuple
from flax import jax_utils
import jax
From f7fd6c7452ead2771a9ebc6cd1e50ba99d5f3d9a Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 18 Jun 2025 21:42:36 +0000
Subject: [PATCH 072/123] formatting
---
algoperf/jax_utils.py | 92 +++++++++----------
.../librispeech_jax/models.py | 2 +-
2 files changed, 45 insertions(+), 49 deletions(-)
diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py
index 369eb1b1a..214a178c6 100644
--- a/algoperf/jax_utils.py
+++ b/algoperf/jax_utils.py
@@ -13,7 +13,7 @@
# Custom Layers
class Dropout(Module):
- """Create a dropout layer.
+ """Create a dropout layer.
Forked from https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout.
The reference dropout implementation is modified support changes to dropout rate during training by:
1) adding rate argument to the __call__ method.
@@ -53,21 +53,21 @@ class Dropout(Module):
rng_collection: the rng collection name to use when requesting an rng key.
"""
- rate: float | None = None
- broadcast_dims: Sequence[int] = ()
- deterministic: bool | None = None
- rng_collection: str = "dropout"
- legacy: bool = True
-
- @compact
- def __call__(
- self,
- inputs,
- deterministic: bool | None = None,
- rate: float | None = None,
- rng: PRNGKey | None = None,
- ):
- """Applies a random dropout mask to the input.
+ rate: float | None = None
+ broadcast_dims: Sequence[int] = ()
+ deterministic: bool | None = None
+ rng_collection: str = "dropout"
+ legacy: bool = True
+
+ @compact
+ def __call__(
+ self,
+ inputs,
+ deterministic: bool | None = None,
+ rate: float | None = None,
+ rng: PRNGKey | None = None,
+ ):
+ """Applies a random dropout mask to the input.
Args:
inputs: the inputs that should be randomly masked.
@@ -81,44 +81,40 @@ def __call__(
Returns:
The masked inputs reweighted to preserve mean.
"""
- deterministic = merge_param("deterministic",
- self.deterministic,
- deterministic)
+ deterministic = merge_param("deterministic", self.deterministic, deterministic)
- # Override self.rate if rate is passed to __call__
- if rate is None:
- rate = self.rate
+ # Override self.rate if rate is passed to __call__
+ if rate is None:
+ rate = self.rate
- if self.legacy:
- if rate == 0.0:
- return inputs
+ if self.legacy:
+ if rate == 0.0:
+ return inputs
- # Prevent gradient NaNs in 1.0 edge-case.
- if rate == 1.0:
- return jnp.zeros_like(inputs)
+ # Prevent gradient NaNs in 1.0 edge-case.
+ if rate == 1.0:
+ return jnp.zeros_like(inputs)
- if deterministic:
- return inputs
+ if deterministic:
+ return inputs
- keep_prob = 1.0 - rate
- if rng is None:
- rng = self.make_rng(self.rng_collection)
- broadcast_shape = list(inputs.shape)
- for dim in self.broadcast_dims:
- broadcast_shape[dim] = 1
- mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
- mask = jnp.broadcast_to(mask, inputs.shape)
- return lax.select(mask, inputs, jnp.zeros_like(inputs))
+ keep_prob = 1.0 - rate
+ if rng is None:
+ rng = self.make_rng(self.rng_collection)
+ broadcast_shape = list(inputs.shape)
+ for dim in self.broadcast_dims:
+ broadcast_shape[dim] = 1
+ mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
+ mask = jnp.broadcast_to(mask, inputs.shape)
+ return lax.select(mask, inputs, jnp.zeros_like(inputs))
# Utilities for debugging
def print_jax_model_summary(model, fake_inputs):
- """Prints a summary of the jax module."""
- tabulate_fn = nn.tabulate(
- model,
- jax.random.PRNGKey(0),
- console_kwargs={
- "force_terminal": False, "force_jupyter": False, "width": 240
- },
- )
- print(tabulate_fn(fake_inputs, train=False))
+ """Prints a summary of the jax module."""
+ tabulate_fn = nn.tabulate(
+ model,
+ jax.random.PRNGKey(0),
+ console_kwargs={"force_terminal": False, "force_jupyter": False, "width": 240},
+ )
+ print(tabulate_fn(fake_inputs, train=False))
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
index fab0b3259..262fc1a95 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
@@ -489,7 +489,7 @@ def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE):
# Subsample input by a factor of 4 by performing strided convolutions.
outputs, output_paddings = Subsample(
- config=config)(outputs, output_paddings, train,
+ config=config)(outputs, output_paddings, train,
dropout_rate=dropout_rate)
# Run the lstm layers.
From 8fc4cc5cd7914d698271bd50583432832e8dc98c Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 18 Jun 2025 21:44:35 +0000
Subject: [PATCH 073/123] fix spacing issues
---
algoperf/jax_utils.py | 92 ++++++++++++++++++++++---------------------
1 file changed, 48 insertions(+), 44 deletions(-)
diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py
index 214a178c6..369eb1b1a 100644
--- a/algoperf/jax_utils.py
+++ b/algoperf/jax_utils.py
@@ -13,7 +13,7 @@
# Custom Layers
class Dropout(Module):
- """Create a dropout layer.
+ """Create a dropout layer.
Forked from https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout.
The reference dropout implementation is modified support changes to dropout rate during training by:
1) adding rate argument to the __call__ method.
@@ -53,21 +53,21 @@ class Dropout(Module):
rng_collection: the rng collection name to use when requesting an rng key.
"""
- rate: float | None = None
- broadcast_dims: Sequence[int] = ()
- deterministic: bool | None = None
- rng_collection: str = "dropout"
- legacy: bool = True
-
- @compact
- def __call__(
- self,
- inputs,
- deterministic: bool | None = None,
- rate: float | None = None,
- rng: PRNGKey | None = None,
- ):
- """Applies a random dropout mask to the input.
+ rate: float | None = None
+ broadcast_dims: Sequence[int] = ()
+ deterministic: bool | None = None
+ rng_collection: str = "dropout"
+ legacy: bool = True
+
+ @compact
+ def __call__(
+ self,
+ inputs,
+ deterministic: bool | None = None,
+ rate: float | None = None,
+ rng: PRNGKey | None = None,
+ ):
+ """Applies a random dropout mask to the input.
Args:
inputs: the inputs that should be randomly masked.
@@ -81,40 +81,44 @@ def __call__(
Returns:
The masked inputs reweighted to preserve mean.
"""
- deterministic = merge_param("deterministic", self.deterministic, deterministic)
+ deterministic = merge_param("deterministic",
+ self.deterministic,
+ deterministic)
- # Override self.rate if rate is passed to __call__
- if rate is None:
- rate = self.rate
+ # Override self.rate if rate is passed to __call__
+ if rate is None:
+ rate = self.rate
- if self.legacy:
- if rate == 0.0:
- return inputs
+ if self.legacy:
+ if rate == 0.0:
+ return inputs
- # Prevent gradient NaNs in 1.0 edge-case.
- if rate == 1.0:
- return jnp.zeros_like(inputs)
+ # Prevent gradient NaNs in 1.0 edge-case.
+ if rate == 1.0:
+ return jnp.zeros_like(inputs)
- if deterministic:
- return inputs
+ if deterministic:
+ return inputs
- keep_prob = 1.0 - rate
- if rng is None:
- rng = self.make_rng(self.rng_collection)
- broadcast_shape = list(inputs.shape)
- for dim in self.broadcast_dims:
- broadcast_shape[dim] = 1
- mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
- mask = jnp.broadcast_to(mask, inputs.shape)
- return lax.select(mask, inputs, jnp.zeros_like(inputs))
+ keep_prob = 1.0 - rate
+ if rng is None:
+ rng = self.make_rng(self.rng_collection)
+ broadcast_shape = list(inputs.shape)
+ for dim in self.broadcast_dims:
+ broadcast_shape[dim] = 1
+ mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
+ mask = jnp.broadcast_to(mask, inputs.shape)
+ return lax.select(mask, inputs, jnp.zeros_like(inputs))
# Utilities for debugging
def print_jax_model_summary(model, fake_inputs):
- """Prints a summary of the jax module."""
- tabulate_fn = nn.tabulate(
- model,
- jax.random.PRNGKey(0),
- console_kwargs={"force_terminal": False, "force_jupyter": False, "width": 240},
- )
- print(tabulate_fn(fake_inputs, train=False))
+ """Prints a summary of the jax module."""
+ tabulate_fn = nn.tabulate(
+ model,
+ jax.random.PRNGKey(0),
+ console_kwargs={
+ "force_terminal": False, "force_jupyter": False, "width": 240
+ },
+ )
+ print(tabulate_fn(fake_inputs, train=False))
From 99c31114af33c6bd2ea5e9fcdc400131dc17bc78 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 18 Jun 2025 21:53:37 +0000
Subject: [PATCH 074/123] formatting
---
pyproject.toml | 1 +
1 file changed, 1 insertion(+)
diff --git a/pyproject.toml b/pyproject.toml
index 4e15e4400..1daa72848 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -131,6 +131,7 @@ pytorch_gpu = [
based_on_style = "yapf"
each_dict_entry_on_separate_line = false
split_all_top_level_comma_separated_values = true
+column_limit = 80
[tool.yapfignore]
ignore_patterns = ["algoperf/_version.py"]
From c2f4ed0eb0fe5ceebcd9a21c7a857de644654f2e Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 18 Jun 2025 22:05:33 +0000
Subject: [PATCH 075/123] formatting
---
algoperf/jax_utils.py | 30 ++++++++++++++++++------------
submission_runner.py | 3 +--
2 files changed, 19 insertions(+), 14 deletions(-)
diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py
index 369eb1b1a..467606241 100644
--- a/algoperf/jax_utils.py
+++ b/algoperf/jax_utils.py
@@ -13,11 +13,15 @@
# Custom Layers
class Dropout(Module):
+ # pylint: disable=line-too-long
"""Create a dropout layer.
- Forked from https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout.
- The reference dropout implementation is modified support changes to dropout rate during training by:
+ Forked from
+ https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout.
+  The reference dropout implementation is modified to support changes
+ to dropout rate during training by:
1) adding rate argument to the __call__ method.
- 2) removing the if-else condition to check for edge cases, which will trigger a recompile for jitted code.
+ 2) removing the if-else condition to check for edge cases, which
+ will trigger a recompile for jitted code.
.. note::
When using :meth:`Module.apply() `, make sure
@@ -47,10 +51,11 @@ class Dropout(Module):
Attributes:
rate: the dropout probability. (_not_ the keep rate!)
broadcast_dims: dimensions that will share the same dropout mask
- deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` and
- masked, whereas if true, no mask is applied and the inputs are returned as
- is.
- rng_collection: the rng collection name to use when requesting an rng key.
+ deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
+ and masked, whereas if true, no mask is applied and the inputs are
+ returned as is.
+ rng_collection: the rng collection name to use when requesting an rng
+ key.
"""
rate: float | None = None
@@ -71,12 +76,13 @@ def __call__(
Args:
inputs: the inputs that should be randomly masked.
- deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` and
- masked, whereas if true, no mask is applied and the inputs are returned
- as is.
+ deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
+ and masked, whereas if true, no mask is applied and the inputs are
+ returned as is.
rate: the dropout probability. (_not_ the keep rate!)
- rng: an optional PRNGKey used as the random key, if not specified, one
- will be generated using ``make_rng`` with the ``rng_collection`` name.
+ rng: an optional PRNGKey used as the random key, if not specified,
+ one will be generated using ``make_rng`` with the
+ ``rng_collection`` name.
Returns:
The masked inputs reweighted to preserve mean.
diff --git a/submission_runner.py b/submission_runner.py
index d076a1043..221a7c21d 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -228,8 +228,7 @@ def train_once(
global_batch_size=global_batch_size)
logging.info('Initializing model.')
with profiler.profile('Initializing model'):
- model_params, model_state = workload.init_model_fn(
- model_init_rng)
+ model_params, model_state = workload.init_model_fn(model_init_rng)
if FLAGS.framework == 'pytorch' and FLAGS.torch_compile:
compile_error_workloads = [
'librispeech_conformer',
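
Taken together with the reference_algorithm_tests change above, the calling convention after these patches is that dropout is no longer fixed at init time but supplied per forward pass. A minimal sketch of that convention (training-loop variables such as batch and step_rng are illustrative; the method signatures follow the workload diffs in this series):

model_params, model_state = workload.init_model_fn(model_init_rng)  # no dropout kwargs at init
logits, model_state = workload.model_fn(
    model_params,
    batch,
    model_state,
    spec.ForwardPassMode.TRAIN,
    step_rng,
    update_batch_norm=True,
    dropout_rate=0.1)  # dropout supplied at call time; defaults to models.DROPOUT_RATE
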
From b20f49dbee50e2c9b1114aa9529fb6a83a012e79 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 19 Jun 2025 00:01:36 +0000
Subject: [PATCH 076/123] formatting
---
algoperf/pytorch_utils.py | 4 ++--
algoperf/workloads/fastmri/fastmri_pytorch/models.py | 1 -
algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py | 2 +-
.../workloads/imagenet_vit/imagenet_pytorch/workload.py | 2 +-
.../librispeech_conformer/librispeech_pytorch/models.py | 5 +++--
.../librispeech_deepspeech/librispeech_pytorch/models.py | 6 ++++--
6 files changed, 11 insertions(+), 9 deletions(-)
diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py
index bae26dea0..f3d8782a4 100644
--- a/algoperf/pytorch_utils.py
+++ b/algoperf/pytorch_utils.py
@@ -89,8 +89,8 @@ def __init__(self):
super().__init__()
self._supports_custom_dropout = True
- def forward(self, input: Tensor, p: float) -> Tensor:
- return F.dropout(input, p, training=self.training)
+ def forward(self, x: Tensor, p: float) -> Tensor:
+ return F.dropout(x, p, training=self.training)
class CustomDropout2d(nn.Module):
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models.py b/algoperf/workloads/fastmri/fastmri_pytorch/models.py
index 3e7d7671c..8441f06c2 100644
--- a/algoperf/workloads/fastmri/fastmri_pytorch/models.py
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/models.py
@@ -5,7 +5,6 @@
"""
from functools import partial
-from typing import Optional
import torch
from torch import nn
diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py
index 9453780d0..6aa3306ba 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py
@@ -14,7 +14,7 @@
from algoperf import init_utils
from algoperf import spec
-from algoperf.workloads.wmt.wmt_pytorch.models_dropout import \
+from algoperf.workloads.wmt.wmt_pytorch.models import \
MultiheadAttention
DROPOUT_RATE = 0.0
diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
index e1c6844fe..d43e90e80 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
@@ -1,7 +1,7 @@
"""ImageNet ViT workload implemented in PyTorch."""
import contextlib
-from typing import Dict, Optional, Tuple
+from typing import Dict, Tuple
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py
index 10b8e585a..d917151a3 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py
@@ -5,7 +5,7 @@
from dataclasses import dataclass
from functools import partial
import math
-from typing import Optional, Tuple
+from typing import Tuple
import torch
from torch import nn
@@ -475,7 +475,8 @@ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE):
outputs, output_paddings = self.preprocessor(outputs, output_paddings)
if self.training and self.config.use_specaug:
outputs, output_paddings = self.specaug(outputs, output_paddings)
- outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate)
+ outputs, output_paddings = self.subsample(outputs, output_paddings,
+ dropout_rate)
for conformer in self.conformers:
outputs = conformer(outputs, output_paddings, dropout_rate)
outputs = self.ln(outputs)
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
index 644c13a16..589793dbd 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
@@ -4,7 +4,7 @@
from dataclasses import dataclass
import os
-from typing import Optional, Tuple
+from typing import Tuple
import torch
from torch import nn
@@ -358,7 +358,9 @@ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE):
outputs, output_paddings = self.preprocessor(outputs, output_paddings)
if self.training and self.config.use_specaug:
outputs, output_paddings = self.specaug(outputs, output_paddings)
- outputs, output_paddings = self.subsample(outputs, output_paddings, dropout_rate)
+ outputs, output_paddings = self.subsample(outputs,
+ output_paddings,
+ dropout_rate)
for idx in range(self.config.num_lstm_layers):
if self.config.enable_residual_connections:
outputs = outputs + self.lstms[idx](outputs, output_paddings)
From 0ea37ee8977c5d33219baada62fdc60decfa5a8d Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 19 Jun 2025 00:02:05 +0000
Subject: [PATCH 077/123] fix
---
.../librispeech_conformer/librispeech_pytorch/models.py | 2 +-
.../librispeech_deepspeech/librispeech_pytorch/models.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py
index d917151a3..3a2eda4af 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py
@@ -475,7 +475,7 @@ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE):
outputs, output_paddings = self.preprocessor(outputs, output_paddings)
if self.training and self.config.use_specaug:
outputs, output_paddings = self.specaug(outputs, output_paddings)
- outputs, output_paddings = self.subsample(outputs, output_paddings,
+ outputs, output_paddings = self.subsample(outputs, output_paddings,
dropout_rate)
for conformer in self.conformers:
outputs = conformer(outputs, output_paddings, dropout_rate)
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
index 589793dbd..3d8c000e1 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
@@ -359,7 +359,7 @@ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE):
if self.training and self.config.use_specaug:
outputs, output_paddings = self.specaug(outputs, output_paddings)
outputs, output_paddings = self.subsample(outputs,
- output_paddings,
+ output_paddings,
dropout_rate)
for idx in range(self.config.num_lstm_layers):
if self.config.enable_residual_connections:
From 594f285c00289e03a1452c2fb674ab917f51f466 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 19 Jun 2025 00:20:35 +0000
Subject: [PATCH 078/123] pylint fixes
---
algoperf/pytorch_utils.py | 6 +++---
algoperf/workloads/imagenet_vit/imagenet_jax/workload.py | 4 +++-
2 files changed, 6 insertions(+), 4 deletions(-)
diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py
index f3d8782a4..bac171ca9 100644
--- a/algoperf/pytorch_utils.py
+++ b/algoperf/pytorch_utils.py
@@ -7,7 +7,7 @@
import torch
from torch import Tensor
import torch.distributed as dist
-import torch.nn as nn
+from torch import nn
import torch.nn.functional as F
from algoperf import spec
@@ -100,8 +100,8 @@ def __init__(self):
super().__init__()
self._supports_custom_dropout = True
- def forward(self, input: Tensor, p: float) -> Tensor:
- return F.dropout2d(input, p, training=self.training)
+ def forward(self, x: Tensor, p: float) -> Tensor:
+ return F.dropout2d(x, p, training=self.training)
class SequentialWithDropout(nn.Sequential):
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
index ab9df0f62..914ab8f86 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -1,6 +1,6 @@
"""ImageNet workload implemented in Jax."""
-from typing import Dict, Tuple
+from typing import Dict, Optional, Tuple
from flax import jax_utils
from flax import linen as nn
@@ -54,10 +54,12 @@ def model_fn(
mode: spec.ForwardPassMode,
rng: spec.RandomState,
update_batch_norm: bool,
+ use_running_average_bn: Optional[bool] = None,
dropout_rate: float = models.DROPOUT_RATE
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
+ del use_running_average_bn
train = mode == spec.ForwardPassMode.TRAIN
logits = self._model.apply({'params': params},
augmented_and_preprocessed_input_batch['inputs'],
From f14ff8f7c0e391578e1d785f53cff51854de8979 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 19 Jun 2025 00:24:55 +0000
Subject: [PATCH 079/123] isort fixes
---
algoperf/pytorch_utils.py | 2 +-
algoperf/workloads/imagenet_vit/imagenet_jax/workload.py | 2 +-
algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py | 3 +--
3 files changed, 3 insertions(+), 4 deletions(-)
diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py
index bac171ca9..b81d2969a 100644
--- a/algoperf/pytorch_utils.py
+++ b/algoperf/pytorch_utils.py
@@ -5,9 +5,9 @@
import jax
import tensorflow as tf
import torch
+from torch import nn
from torch import Tensor
import torch.distributed as dist
-from torch import nn
import torch.nn.functional as F
from algoperf import spec
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
index 914ab8f86..0e320b9b9 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -1,6 +1,6 @@
"""ImageNet workload implemented in Jax."""
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional, Tuple
from flax import jax_utils
from flax import linen as nn
diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py
index 6aa3306ba..cb503cd9f 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py
@@ -14,8 +14,7 @@
from algoperf import init_utils
from algoperf import spec
-from algoperf.workloads.wmt.wmt_pytorch.models import \
- MultiheadAttention
+from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention
DROPOUT_RATE = 0.0
From 2a8586a34bdd0baa8f641d2bc52fa37805186ca7 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 19 Jun 2025 00:36:48 +0000
Subject: [PATCH 080/123] pylint fixes
---
.../librispeech_deepspeech/librispeech_pytorch/workload.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
index 7640d69a5..0b9ce1e3c 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
@@ -65,6 +65,7 @@ def model_fn(
dropout_rate: float = models.DROPOUT_RATE
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
# override super method, changing only the default dropout_rate
+ # pylint: disable=useless-parent-delegation
return super().model_fn(params,
augmented_and_preprocessed_input_batch,
model_state,
From ad36a7c3a409146c6bc6ebb9301b7b9beb1a6716 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Sat, 21 Jun 2025 03:06:36 +0000
Subject: [PATCH 081/123] add dropout tests
---
algoperf/jax_utils.py | 4 +-
.../criteo1tb/criteo1tb_jax/models_ref.py | 219 ++++++
.../fastmri/fastmri_jax/models_ref.py | 220 ++++++
.../imagenet_jax/models_ref.py | 132 ++++
.../imagenet_vit/imagenet_jax/models_ref.py | 235 ++++++
.../librispeech_jax/models_ref.py | 712 ++++++++++++++++++
.../librispeech_jax/models_ref.py | 525 +++++++++++++
.../workloads/ogbg/ogbg_jax/models_ref.py | 88 +++
algoperf/workloads/wmt/wmt_jax/models_ref.py | 604 +++++++++++++++
.../criteo1tb_jax/test_model_equivalence.py | 156 ++++
.../fastmri_jax/test_model_equivalence.py | 130 ++++
.../test_model_equivalence.py | 138 ++++
.../test_model_equivalence.py | 84 +++
.../test_model_equivalence.py | 124 +++
.../test_model_equivalence.py | 118 +++
.../wmt_pytorch_jax/test_model_equivalence.py | 121 +++
tests/test_jax_utils.py | 237 ++++++
17 files changed, 3845 insertions(+), 2 deletions(-)
create mode 100644 algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py
create mode 100644 algoperf/workloads/fastmri/fastmri_jax/models_ref.py
create mode 100644 algoperf/workloads/imagenet_resnet/imagenet_jax/models_ref.py
create mode 100644 algoperf/workloads/imagenet_vit/imagenet_jax/models_ref.py
create mode 100644 algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py
create mode 100644 algoperf/workloads/librispeech_deepspeech/librispeech_jax/models_ref.py
create mode 100644 algoperf/workloads/ogbg/ogbg_jax/models_ref.py
create mode 100644 algoperf/workloads/wmt/wmt_jax/models_ref.py
create mode 100644 tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py
create mode 100644 tests/dropout_fix/fastmri_jax/test_model_equivalence.py
create mode 100644 tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py
create mode 100644 tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py
create mode 100644 tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py
create mode 100644 tests/dropout_fix/ogbg_pytorch_jax/test_model_equivalence.py
create mode 100644 tests/dropout_fix/wmt_pytorch_jax/test_model_equivalence.py
create mode 100644 tests/test_jax_utils.py
diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py
index 467606241..28a4ba8c9 100644
--- a/algoperf/jax_utils.py
+++ b/algoperf/jax_utils.py
@@ -62,7 +62,7 @@ class Dropout(Module):
broadcast_dims: Sequence[int] = ()
deterministic: bool | None = None
rng_collection: str = "dropout"
- legacy: bool = True
+ legacy: bool = False
@compact
def __call__(
@@ -114,7 +114,7 @@ def __call__(
broadcast_shape[dim] = 1
mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
mask = jnp.broadcast_to(mask, inputs.shape)
- return lax.select(mask, inputs, jnp.zeros_like(inputs))
+ return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
# Utilities for debugging
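
With legacy now defaulting to False, the modified Dropout applies the usual inverted scaling and takes the rate per call. A minimal usage sketch (assuming algoperf.jax_utils is importable as created by this series; the key and shapes are illustrative):

import jax
import jax.numpy as jnp
from algoperf.jax_utils import Dropout

x = jnp.ones((2, 4))
# The attribute rate=0.0 is overridden by the call-time rate.
y = Dropout(rate=0.0).apply(
    {},  # no learnable parameters
    x,
    deterministic=False,
    rate=0.5,
    rngs={"dropout": jax.random.PRNGKey(0)})
# Kept entries are rescaled by 1 / (1 - rate), so the dropout mask preserves the mean.
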
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py
new file mode 100644
index 000000000..8406b9eb1
--- /dev/null
+++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py
@@ -0,0 +1,219 @@
+"""A JAX implementation of DLRM-Small."""
+
+from typing import Sequence
+
+import flax.linen as nn
+from jax import nn as jnn
+import jax.numpy as jnp
+
+
+class DLRMResNet(nn.Module):
+ """Define a DLRMResNet model.
+
+ Parameters:
+ vocab_size: the size of a single unified embedding table.
+ mlp_bottom_dims: dimensions of dense layers of the bottom mlp.
+ mlp_top_dims: dimensions of dense layers of the top mlp.
+ num_dense_features: number of dense features as the bottom mlp input.
+ embed_dim: embedding dimension.
+ """
+
+ vocab_size: int = 32 * 128 * 1024 # 4_194_304
+ num_dense_features: int = 13
+ mlp_bottom_dims: Sequence[int] = (256, 256, 256)
+ mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1)
+ embed_dim: int = 128
+ dropout_rate: float = 0.0
+ use_layer_norm: bool = False # Unused.
+ embedding_init_multiplier: float = None # Unused
+
+ @nn.compact
+ def __call__(self, x, train):
+ bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
+ cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
+
+ # bottom mlp
+ mlp_bottom_dims = self.mlp_bottom_dims
+
+ bot_mlp_input = nn.Dense(
+ mlp_bottom_dims[0],
+ kernel_init=jnn.initializers.glorot_uniform(),
+ bias_init=jnn.initializers.normal(stddev=1.0 / mlp_bottom_dims[0]**0.5),
+ )(
+ bot_mlp_input)
+ bot_mlp_input = nn.relu(bot_mlp_input)
+
+ for dense_dim in mlp_bottom_dims[1:]:
+ x = nn.Dense(
+ dense_dim,
+ kernel_init=jnn.initializers.glorot_uniform(),
+ bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5),
+ )(
+ bot_mlp_input)
+ bot_mlp_input += nn.relu(x)
+
+ base_init_fn = jnn.initializers.uniform(scale=1.0)
+ # Embedding table init and lookup for a single unified table.
+ idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size
+
+ def scaled_init(key, shape, dtype=jnp.float_):
+ return base_init_fn(key, shape, dtype) / jnp.sqrt(self.vocab_size)
+
+ embedding_table = self.param('embedding_table',
+ scaled_init, [self.vocab_size, self.embed_dim])
+
+ embed_features = embedding_table[idx_lookup]
+ batch_size = bot_mlp_input.shape[0]
+ embed_features = jnp.reshape(embed_features,
+ (batch_size, 26 * self.embed_dim))
+ top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1)
+ mlp_input_dim = top_mlp_input.shape[1]
+ mlp_top_dims = self.mlp_top_dims
+ num_layers_top = len(mlp_top_dims)
+ top_mlp_input = nn.Dense(
+ mlp_top_dims[0],
+ kernel_init=jnn.initializers.normal(
+ stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0]))),
+ bias_init=jnn.initializers.normal(
+ stddev=jnp.sqrt(1.0 / mlp_top_dims[0])))(
+ top_mlp_input)
+ top_mlp_input = nn.relu(top_mlp_input)
+ for layer_idx, fan_out in list(enumerate(mlp_top_dims))[1:-1]:
+ fan_in = mlp_top_dims[layer_idx - 1]
+ x = nn.Dense(
+ fan_out,
+ kernel_init=jnn.initializers.normal(
+ stddev=jnp.sqrt(2.0 / (fan_in + fan_out))),
+ bias_init=jnn.initializers.normal(
+ stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))(
+ top_mlp_input)
+ x = nn.relu(x)
+ if self.dropout_rate and layer_idx == num_layers_top - 2:
+ x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
+ top_mlp_input += x
+ # In the DLRM model the last layer width is always 1. We can hardcode that
+ # below.
+ logits = nn.Dense(
+ 1,
+ kernel_init=jnn.initializers.normal(
+ stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))),
+ bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0)))(
+ top_mlp_input)
+ return logits
+
+
+def dot_interact(concat_features):
+ """Performs feature interaction operation between dense or sparse features.
+ Input tensors represent dense or sparse features.
+ Pre-condition: The tensors have been stacked along dimension 1.
+ Args:
+ concat_features: Array of features with shape [B, n_features, feature_dim].
+ Returns:
+ activations: Array representing interacted features.
+ """
+ batch_size = concat_features.shape[0]
+
+ # Interact features, select upper or lower-triangular portion, and reshape.
+ xactions = jnp.matmul(concat_features,
+ jnp.transpose(concat_features, [0, 2, 1]))
+ feature_dim = xactions.shape[-1]
+
+ indices = jnp.array(jnp.triu_indices(feature_dim))
+ num_elems = indices.shape[1]
+ indices = jnp.tile(indices, [1, batch_size])
+ indices0 = jnp.reshape(
+ jnp.tile(jnp.reshape(jnp.arange(batch_size), [-1, 1]), [1, num_elems]),
+ [1, -1])
+ indices = tuple(jnp.concatenate((indices0, indices), 0))
+ activations = xactions[indices]
+ activations = jnp.reshape(activations, [batch_size, -1])
+ return activations
+
+
+class DlrmSmall(nn.Module):
+ """Define a DLRM-Small model.
+
+ Parameters:
+ vocab_size: vocab size of embedding table.
+ num_dense_features: number of dense features as the bottom mlp input.
+ mlp_bottom_dims: dimensions of dense layers of the bottom mlp.
+ mlp_top_dims: dimensions of dense layers of the top mlp.
+ embed_dim: embedding dimension.
+ """
+
+ vocab_size: int = 32 * 128 * 1024 # 4_194_304.
+ num_dense_features: int = 13
+ mlp_bottom_dims: Sequence[int] = (512, 256, 128)
+ mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1)
+ embed_dim: int = 128
+ dropout_rate: float = 0.0
+ use_layer_norm: bool = False
+ embedding_init_multiplier: float = None
+
+ @nn.compact
+ def __call__(self, x, train):
+ bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
+ cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
+
+ # Bottom MLP.
+ for dense_dim in self.mlp_bottom_dims:
+ bot_mlp_input = nn.Dense(
+ dense_dim,
+ kernel_init=jnn.initializers.glorot_uniform(),
+ bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)),
+ )(
+ bot_mlp_input)
+ bot_mlp_input = nn.relu(bot_mlp_input)
+ if self.use_layer_norm:
+ bot_mlp_input = nn.LayerNorm()(bot_mlp_input)
+ bot_mlp_output = bot_mlp_input
+ batch_size = bot_mlp_output.shape[0]
+ feature_stack = jnp.reshape(bot_mlp_output,
+ [batch_size, -1, self.embed_dim])
+
+ # Embedding table look-up.
+ idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size
+
+ if self.embedding_init_multiplier is None:
+ scale = 1 / jnp.sqrt(self.vocab_size)
+ else:
+ scale = self.embedding_init_multiplier
+
+ def scaled_init(key, shape, dtype=jnp.float_):
+ return jnn.initializers.uniform(scale=1.0)(key, shape, dtype) * scale
+
+ embedding_table = self.param('embedding_table',
+ scaled_init, [self.vocab_size, self.embed_dim])
+
+ idx_lookup = jnp.reshape(idx_lookup, [-1])
+ embed_features = embedding_table[idx_lookup]
+ embed_features = jnp.reshape(embed_features,
+ [batch_size, -1, self.embed_dim])
+ if self.use_layer_norm:
+ embed_features = nn.LayerNorm()(embed_features)
+ feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1)
+ dot_interact_output = dot_interact(concat_features=feature_stack)
+ top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output],
+ axis=-1)
+ mlp_input_dim = top_mlp_input.shape[1]
+ mlp_top_dims = self.mlp_top_dims
+ num_layers_top = len(mlp_top_dims)
+ for layer_idx, fan_out in enumerate(mlp_top_dims):
+ fan_in = mlp_input_dim if layer_idx == 0 else mlp_top_dims[layer_idx - 1]
+ top_mlp_input = nn.Dense(
+ fan_out,
+ kernel_init=jnn.initializers.normal(
+ stddev=jnp.sqrt(2.0 / (fan_in + fan_out))),
+ bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)))(
+ top_mlp_input)
+ if layer_idx < (num_layers_top - 1):
+ top_mlp_input = nn.relu(top_mlp_input)
+ if self.use_layer_norm:
+ top_mlp_input = nn.LayerNorm()(top_mlp_input)
+ if (self.dropout_rate is not None and self.dropout_rate > 0.0 and
+ layer_idx == num_layers_top - 2):
+ top_mlp_input = nn.Dropout(
+ rate=self.dropout_rate, deterministic=not train)(
+ top_mlp_input)
+ logits = top_mlp_input
+ return logits
\ No newline at end of file
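
As a quick shape check for the dot_interact helper defined in the reference model above (a sketch importing the models_ref module created by this patch; the sizes follow the DlrmSmall defaults of 26 embedded categorical features plus one bottom-MLP output of width embed_dim=128):

import jax.numpy as jnp
from algoperf.workloads.criteo1tb.criteo1tb_jax.models_ref import dot_interact

concat_features = jnp.ones((2, 27, 128))  # [batch, n_features, feature_dim]
out = dot_interact(concat_features)
assert out.shape == (2, 27 * 28 // 2)  # upper triangle incl. diagonal: 378 interaction terms
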
diff --git a/algoperf/workloads/fastmri/fastmri_jax/models_ref.py b/algoperf/workloads/fastmri/fastmri_jax/models_ref.py
new file mode 100644
index 000000000..a2d56a4b4
--- /dev/null
+++ b/algoperf/workloads/fastmri/fastmri_jax/models_ref.py
@@ -0,0 +1,220 @@
+"""Jax / Flax implementation of FastMRI U-Net.
+
+Forked from
+https://github.com/google/init2winit/blob/master/init2winit/model_lib/unet.py
+
+Original implementation:
+github.com/facebookresearch/fastMRI/blob/main/fastmri/models/unet.py
+
+Training:
+github.com/facebookresearch/fastMRI/blob/main/fastmri/pl_modules/unet_module.py
+
+Data:
+github.com/facebookresearch/fastMRI/tree/main/fastmri/data
+"""
+import functools
+from typing import Optional
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+
+
+def _instance_norm2d(x, axes, epsilon=1e-5):
+ # promote x to at least float32, this avoids half precision computation
+ # but preserves double or complex floating points
+ x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x)))
+ mean = jnp.mean(x, axes)
+ mean2 = jnp.mean(jnp.square(x), axes)
+ # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
+ # to floating point round-off errors.
+ var = jnp.maximum(0., mean2 - jnp.square(mean))
+ stats_shape = list(x.shape)
+ for axis in axes:
+ stats_shape[axis] = 1
+ mean = mean.reshape(stats_shape)
+ var = var.reshape(stats_shape)
+ y = x - mean
+ mul = jnp.sqrt(var + epsilon)
+ y /= mul
+ return y
+
+
+class UNet(nn.Module):
+ """Jax / Flax implementation of a U-Net model.
+
+ O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks
+ for biomedical image segmentation. In International Conference on Medical
+ image computing and computer-assisted intervention, pages 234–241.
+ Springer, 2015.
+
+ out_channels: Number of channels in the output to the U-Net model.
+ channels: Number of output channels of the first convolution layer.
+ num_pool_layers: Number of down-sampling and up-sampling layers.
+ dropout_rate: Dropout probability.
+ """
+ num_channels: int = 32
+ num_pool_layers: int = 4
+ out_channels = 1
+ dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0.
+ use_tanh: bool = False
+ use_layer_norm: bool = False
+
+ @nn.compact
+ def __call__(self, x, train=True):
+ dropout_rate = self.dropout_rate
+ if dropout_rate is None:
+ dropout_rate = 0.0
+
+ # pylint: disable=invalid-name
+ _ConvBlock = functools.partial(
+ ConvBlock,
+ dropout_rate=dropout_rate,
+ use_tanh=self.use_tanh,
+ use_layer_norm=self.use_layer_norm)
+ _TransposeConvBlock = functools.partial(
+ TransposeConvBlock,
+ use_tanh=self.use_tanh,
+ use_layer_norm=self.use_layer_norm)
+
+ down_sample_layers = [_ConvBlock(self.num_channels)]
+
+ ch = self.num_channels
+ for _ in range(self.num_pool_layers - 1):
+ down_sample_layers.append(_ConvBlock(ch * 2))
+ ch *= 2
+ conv = _ConvBlock(ch * 2)
+
+ up_conv = []
+ up_transpose_conv = []
+ for _ in range(self.num_pool_layers - 1):
+ up_transpose_conv.append(_TransposeConvBlock(ch))
+ up_conv.append(_ConvBlock(ch))
+ ch //= 2
+
+ up_transpose_conv.append(_TransposeConvBlock(ch))
+ up_conv.append(_ConvBlock(ch))
+
+ stack = []
+ output = jnp.expand_dims(x, axis=-1)
+
+ # apply down-sampling layers
+ for layer in down_sample_layers:
+ output = layer(output, train)
+ stack.append(output)
+ output = nn.avg_pool(output, window_shape=(2, 2), strides=(2, 2))
+
+ output = conv(output, train)
+
+ # apply up-sampling layers
+ for transpose_conv, conv in zip(up_transpose_conv, up_conv):
+ downsample_layer = stack.pop()
+ output = transpose_conv(output)
+
+      # reflect pad on the right/bottom if needed to handle odd input dimensions
+ padding_right = 0
+ padding_bottom = 0
+ if output.shape[-2] != downsample_layer.shape[-2]:
+ padding_right = 1 # padding right
+ if output.shape[-3] != downsample_layer.shape[-3]:
+ padding_bottom = 1 # padding bottom
+
+ if padding_right or padding_bottom:
+ padding = ((0, 0), (0, padding_bottom), (0, padding_right), (0, 0))
+ output = jnp.pad(output, padding, mode='reflect')
+
+ output = jnp.concatenate((output, downsample_layer), axis=-1)
+ output = conv(output, train)
+
+ output = nn.Conv(
+ self.out_channels, kernel_size=(1, 1), strides=(1, 1))(
+ output)
+ return output.squeeze(-1)
+
+
+class ConvBlock(nn.Module):
+ """A Convolutional Block.
+ out_channels: Number of channels in the output.
+ dropout_rate: Dropout probability.
+ """
+ out_channels: int
+ dropout_rate: float
+ use_tanh: bool
+ use_layer_norm: bool
+
+ @nn.compact
+ def __call__(self, x, train=True):
+ """Forward function.
+ Note: Pytorch is NCHW and jax/flax is NHWC.
+ Args:
+ x: Input 4D tensor of shape `(N, H, W, in_channels)`.
+ train: deterministic or not (use init2winit naming).
+ Returns:
+ jnp.array: Output tensor of shape `(N, H, W, out_channels)`.
+ """
+ x = nn.Conv(
+ features=self.out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ use_bias=False)(
+ x)
+ if self.use_layer_norm:
+ x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x)
+ else:
+ # InstanceNorm2d was run with no learnable params in reference code
+ # so this is a simple normalization along spatial dims.
+ x = _instance_norm2d(x, (1, 2))
+ if self.use_tanh:
+ activation_fn = nn.tanh
+ else:
+ activation_fn = functools.partial(jax.nn.leaky_relu, negative_slope=0.2)
+ x = activation_fn(x)
+ # Ref code uses dropout2d which applies the same mask for the entire channel
+ # Replicated by using broadcast dims to have the same filter on HW
+ x = nn.Dropout(
+ self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
+ x)
+ x = nn.Conv(
+ features=self.out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ use_bias=False)(
+ x)
+ if self.use_layer_norm:
+ x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x)
+ else:
+ x = _instance_norm2d(x, (1, 2))
+ x = activation_fn(x)
+ x = nn.Dropout(
+ self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
+ x)
+ return x
+
+
+class TransposeConvBlock(nn.Module):
+ """A Transpose Convolutional Block.
+ out_channels: Number of channels in the output.
+ """
+ out_channels: int
+ use_tanh: bool
+ use_layer_norm: bool
+
+ @nn.compact
+ def __call__(self, x):
+ """Forward function.
+ Args:
+ x: Input 4D tensor of shape `(N, H, W, in_channels)`.
+ Returns:
+ jnp.array: Output tensor of shape `(N, H*2, W*2, out_channels)`.
+ """
+ x = nn.ConvTranspose(
+ self.out_channels, kernel_size=(2, 2), strides=(2, 2), use_bias=False)(
+ x)
+ x = _instance_norm2d(x, (1, 2))
+ if self.use_tanh:
+ activation_fn = nn.tanh
+ else:
+ activation_fn = functools.partial(jax.nn.leaky_relu, negative_slope=0.2)
+ x = activation_fn(x)
+ return x
\ No newline at end of file
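
The _instance_norm2d helper above normalizes each example and channel over the spatial dimensions, standing in for a parameter-free InstanceNorm2d. A small sketch of its effect (importing the models_ref module created by this patch; the input values are arbitrary):

import jax.numpy as jnp
from algoperf.workloads.fastmri.fastmri_jax.models_ref import _instance_norm2d

x = jnp.arange(2 * 8 * 8 * 3, dtype=jnp.float32).reshape(2, 8, 8, 3)
y = _instance_norm2d(x, (1, 2))  # normalize over H and W for each (example, channel)
# Per-(example, channel) spatial mean is ~0 and variance ~1 after normalization.
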
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/models_ref.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/models_ref.py
new file mode 100644
index 000000000..357dadc13
--- /dev/null
+++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/models_ref.py
@@ -0,0 +1,132 @@
+"""Jax implementation of ResNet V1.
+
+Adapted from Flax example:
+https://github.com/google/flax/blob/main/examples/imagenet/models.py.
+"""
+
+import functools
+from typing import Any, Callable, Optional, Tuple
+
+from flax import linen as nn
+import jax.numpy as jnp
+
+from algoperf import spec
+
+ModuleDef = nn.Module
+
+
+class ResNetBlock(nn.Module):
+ """ResNet block."""
+ filters: int
+ conv: ModuleDef
+ norm: ModuleDef
+ act: Callable
+ strides: Tuple[int, int] = (1, 1)
+ bn_init_scale: float = 0.
+
+ @nn.compact
+ def __call__(self, x: spec.Tensor) -> spec.Tensor:
+ residual = x
+ y = self.conv(self.filters, (3, 3), self.strides)(x)
+ y = self.norm()(y)
+ y = self.act(y)
+ y = self.conv(self.filters, (3, 3))(y)
+ y = self.norm(scale_init=nn.initializers.constant(self.bn_init_scale))(y)
+
+ if residual.shape != y.shape or self.strides != (1, 1):
+ residual = self.conv(
+ self.filters, (1, 1), self.strides, name='Conv_proj')(
+ residual)
+ residual = self.norm(name='BatchNorm_proj')(residual)
+
+ return self.act(residual + y)
+
+
+class BottleneckResNetBlock(nn.Module):
+ """Bottleneck ResNet block."""
+ filters: int
+ conv: ModuleDef
+ norm: ModuleDef
+ act: Callable
+ strides: Tuple[int, int] = (1, 1)
+ bn_init_scale: Optional[float] = None
+
+ @nn.compact
+ def __call__(self, x: spec.Tensor) -> spec.Tensor:
+ residual = x
+ y = self.conv(self.filters, (1, 1))(x)
+ y = self.norm()(y)
+ y = self.act(y)
+ y = self.conv(self.filters, (3, 3), self.strides)(y)
+ y = self.norm()(y)
+ y = self.act(y)
+ y = self.conv(self.filters * 4, (1, 1))(y)
+ y = self.norm(scale_init=nn.initializers.constant(self.bn_init_scale))(y)
+
+ if residual.shape != y.shape or self.strides != (1, 1):
+ residual = self.conv(
+ self.filters * 4, (1, 1), self.strides, name='Conv_proj')(
+ residual)
+ residual = self.norm(name='BatchNorm_proj')(residual)
+
+ return self.act(residual + y)
+
+
+class ResNet(nn.Module):
+ stage_sizes: Tuple[int]
+ block_cls: ModuleDef
+ num_classes: int
+ num_filters: int = 64
+ dtype: Any = jnp.float32
+ act: Callable = nn.relu
+ bn_init_scale: float = 0.
+
+ @nn.compact
+ def __call__(self,
+ x: spec.Tensor,
+ update_batch_norm: bool = True,
+ use_running_average_bn: Optional[bool] = None) -> spec.Tensor:
+ conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
+
+ # Preserve default behavior for backwards compatibility
+ if use_running_average_bn is None:
+ use_running_average_bn = not update_batch_norm
+ norm = functools.partial(
+ nn.BatchNorm,
+ use_running_average=use_running_average_bn,
+ momentum=0.9,
+ epsilon=1e-5,
+ dtype=self.dtype)
+
+ x = conv(
+ self.num_filters, (7, 7), (2, 2),
+ padding=[(3, 3), (3, 3)],
+ name='Conv_init')(
+ x)
+ x = norm(name='BatchNorm_init')(x)
+ x = self.act(x)
+ x = nn.max_pool(x, (3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))
+ for i, block_size in enumerate(self.stage_sizes):
+ for j in range(block_size):
+ strides = (2, 2) if i > 0 and j == 0 else (1, 1)
+ x = self.block_cls(
+ self.num_filters * 2**i,
+ strides=strides,
+ conv=conv,
+ norm=norm,
+ act=self.act,
+ bn_init_scale=self.bn_init_scale)(
+ x)
+ x = jnp.mean(x, axis=(1, 2))
+ x = nn.Dense(
+ self.num_classes,
+ kernel_init=nn.initializers.normal(),
+ dtype=self.dtype)(
+ x)
+ return x
+
+
+ResNet18 = functools.partial(
+ ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock)
+ResNet50 = functools.partial(
+ ResNet, stage_sizes=(3, 4, 6, 3), block_cls=BottleneckResNetBlock)
\ No newline at end of file
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models_ref.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models_ref.py
new file mode 100644
index 000000000..beb8a2eb8
--- /dev/null
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models_ref.py
@@ -0,0 +1,235 @@
+"""Jax implementation of refactored and simplified ViT.
+
+Forked from:
+https://github.com/google/init2winit/blob/master/init2winit/model_lib/vit.py,
+originally from https://github.com/google/big_vision with modifications noted.
+"""
+
+from typing import Optional, Sequence, Union
+
+from flax import linen as nn
+import jax.numpy as jnp
+
+from algoperf import spec
+
+
+def posemb_sincos_2d(h: int,
+ w: int,
+ width: int,
+ temperature: int = 10_000.,
+ dtype: jnp.dtype = jnp.float32) -> spec.Tensor:
+ """Follows the MoCo v3 logic."""
+ y, x = jnp.mgrid[:h, :w] #pylint: disable=unpacking-non-sequence
+
+ if width % 4 != 0:
+ raise ValueError('Width must be mult of 4 for sincos posemb.')
+ omega = jnp.arange(width // 4) / (width // 4 - 1)
+ omega = 1. / (temperature**omega)
+ y = jnp.einsum('m,d->md', y.flatten(), omega)
+ x = jnp.einsum('m,d->md', x.flatten(), omega)
+ pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1)
+ return jnp.asarray(pe, dtype)[None, :, :]
+
+
+class MlpBlock(nn.Module):
+ """Transformer MLP / feed-forward block."""
+ mlp_dim: Optional[int] = None # Defaults to 4x input dim.
+ use_glu: bool = False
+ dropout_rate: float = 0.0
+
+ @nn.compact
+ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
+ """Applies Transformer MlpBlock module."""
+ inits = {
+ 'kernel_init': nn.initializers.xavier_uniform(),
+ 'bias_init': nn.initializers.normal(stddev=1e-6),
+ }
+
+ d = x.shape[2]
+ x = nn.Dense(self.mlp_dim or 4 * d, **inits)(x)
+ x = nn.gelu(x)
+
+ if self.use_glu:
+ y = nn.Dense(self.mlp_dim, **inits)(x)
+ x = x * y
+
+ x = nn.Dropout(rate=self.dropout_rate)(x, train)
+ x = nn.Dense(d, **inits)(x)
+ return x
+
+
+class Encoder1DBlock(nn.Module):
+ """Single transformer encoder block (MHSA + MLP)."""
+ mlp_dim: Optional[int] = None # Defaults to 4x input dim.
+ num_heads: int = 12
+ use_glu: bool = False
+ use_post_layer_norm: bool = False
+ dropout_rate: float = 0.0
+
+ @nn.compact
+ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
+ if not self.use_post_layer_norm:
+ y = nn.LayerNorm(name='LayerNorm_0')(x)
+ y = nn.MultiHeadDotProductAttention(
+ num_heads=self.num_heads,
+ kernel_init=nn.initializers.xavier_uniform(),
+ deterministic=train,
+ name='MultiHeadDotProductAttention_1')(
+ y)
+ y = nn.Dropout(rate=self.dropout_rate)(y, train)
+ x = x + y
+
+ y = nn.LayerNorm(name='LayerNorm_2')(x)
+ y = MlpBlock(
+ mlp_dim=self.mlp_dim,
+ use_glu=self.use_glu,
+ dropout_rate=self.dropout_rate,
+ name='MlpBlock_3')(y, train)
+ y = nn.Dropout(rate=self.dropout_rate)(y, train)
+ x = x + y
+ else:
+ y = x
+ y = nn.MultiHeadDotProductAttention(
+ num_heads=self.num_heads,
+ kernel_init=nn.initializers.xavier_uniform(),
+ deterministic=train,
+ name='MultiHeadDotProductAttention_1')(
+ y)
+ y = nn.Dropout(rate=self.dropout_rate)(y, train)
+ x = x + y
+ x = nn.LayerNorm(name='LayerNorm_0')(x)
+
+ y = x
+ y = MlpBlock(
+ mlp_dim=self.mlp_dim,
+ use_glu=self.use_glu,
+ dropout_rate=self.dropout_rate,
+ name='MlpBlock_3')(y, train)
+ y = nn.Dropout(rate=self.dropout_rate)(y, train)
+ x = x + y
+ x = nn.LayerNorm(name='LayerNorm_2')(x)
+
+ return x
+
+
+class Encoder(nn.Module):
+ """Transformer Model Encoder for sequence to sequence translation."""
+ depth: int
+ mlp_dim: Optional[int] = None # Defaults to 4x input dim.
+ num_heads: int = 12
+ dropout_rate: float = 0.0
+ use_glu: bool = False
+ use_post_layer_norm: bool = False
+
+ @nn.compact
+ def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
+ # Input Encoder
+ for lyr in range(self.depth):
+ block = Encoder1DBlock(
+ name=f'encoderblock_{lyr}',
+ mlp_dim=self.mlp_dim,
+ num_heads=self.num_heads,
+ use_glu=self.use_glu,
+ use_post_layer_norm=self.use_post_layer_norm,
+ dropout_rate=self.dropout_rate)
+ x = block(x, train)
+ if not self.use_post_layer_norm:
+ return nn.LayerNorm(name='encoder_layernorm')(x)
+ else:
+ return x
+
+
+class MAPHead(nn.Module):
+ """Multihead Attention Pooling."""
+ mlp_dim: Optional[int] = None # Defaults to 4x input dim
+ num_heads: int = 12
+
+ @nn.compact
+ def __call__(self, x):
+ n, _, d = x.shape
+ probe = self.param('probe',
+ nn.initializers.xavier_uniform(), (1, 1, d),
+ x.dtype)
+ probe = jnp.tile(probe, [n, 1, 1])
+
+ x = nn.MultiHeadDotProductAttention(
+ num_heads=self.num_heads,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform())(probe, x)
+
+ y = nn.LayerNorm()(x)
+ x = x + MlpBlock(mlp_dim=self.mlp_dim)(y)
+ return x[:, 0]
+
+
+class ViT(nn.Module):
+ """ViT model."""
+
+ num_classes: int = 1000
+ patch_size: Sequence[int] = (16, 16)
+ width: int = 768
+ depth: int = 12
+ mlp_dim: Optional[int] = None # Defaults to 4x input dim.
+ num_heads: int = 12
+ rep_size: Union[int, bool] = True
+ dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0.
+ reinit: Optional[Sequence[str]] = None
+ head_zeroinit: bool = True
+ use_glu: bool = False
+ use_post_layer_norm: bool = False
+ use_map: bool = False
+
+ def get_posemb(self,
+ seqshape: tuple,
+ width: int,
+ dtype: jnp.dtype = jnp.float32) -> spec.Tensor:
+ return posemb_sincos_2d(*seqshape, width, dtype=dtype)
+
+ @nn.compact
+ def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor:
+ # Patch extraction
+ x = nn.Conv(
+ self.width,
+ self.patch_size,
+ strides=self.patch_size,
+ padding='VALID',
+ name='conv_patch_extract')(
+ x)
+
+ n, h, w, c = x.shape
+ x = jnp.reshape(x, [n, h * w, c])
+
+ # Add posemb before adding extra token.
+ x = x + self.get_posemb((h, w), c, x.dtype)
+
+ dropout_rate = self.dropout_rate
+ if dropout_rate is None:
+ dropout_rate = 0.0
+ x = nn.Dropout(rate=dropout_rate)(x, not train)
+
+ x = Encoder(
+ depth=self.depth,
+ mlp_dim=self.mlp_dim,
+ num_heads=self.num_heads,
+ use_glu=self.use_glu,
+ use_post_layer_norm=self.use_post_layer_norm,
+ dropout_rate=dropout_rate,
+ name='Transformer')(
+ x, train=not train)
+
+ if self.use_map:
+ x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x)
+ else:
+ x = jnp.mean(x, axis=1)
+
+ if self.rep_size:
+ rep_size = self.width if self.rep_size is True else self.rep_size
+ hid = nn.Dense(rep_size, name='pre_logits')
+ x = nn.tanh(hid(x))
+
+ if self.num_classes:
+ kw = {'kernel_init': nn.initializers.zeros} if self.head_zeroinit else {}
+ head = nn.Dense(self.num_classes, name='head', **kw)
+ x = head(x)
+
+ return x
\ No newline at end of file
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py
new file mode 100644
index 000000000..969d9423c
--- /dev/null
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py
@@ -0,0 +1,712 @@
+r"""Conformer.
+
+This model uses a conformer network to convert speech to text.
+paper : https://arxiv.org/abs/2005.08100
+
+high-level overview of Conformer encoder layer.
+
+ x = x + 0.5 * FeedForward(x)
+ x = x + MHSA(x)
+ x = x + ConvolutionBlock(x)
+ x = x + 0.5 * FeedForward(x)
+ y = layer_norm(x)
+"""
+
+import functools
+import math
+from typing import Any, List, Optional
+
+from flax import linen as nn
+from flax import struct
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from algoperf.workloads.librispeech_conformer.librispeech_jax import \
+ librispeech_preprocessor as preprocessor
+from algoperf.workloads.librispeech_conformer.librispeech_jax import \
+ spectrum_augmenter
+
+
+@struct.dataclass
+class ConformerConfig:
+ """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
+ vocab_size: int = 1024
+ dtype: Any = jnp.float32
+ encoder_dim: int = 512
+ num_attention_heads: int = 8
+ num_encoder_layers: int = 4
+ attention_dropout_rate: float = 0.0
+ # If None, defaults to 0.1.
+ attention_residual_dropout_rate: Optional[float] = 0.1
+ # If None, defaults to 0.0.
+ conv_residual_dropout_rate: Optional[float] = 0.0
+ feed_forward_dropout_rate: float = 0.0
+ # If None, defaults to 0.1.
+ feed_forward_residual_dropout_rate: Optional[float] = 0.1
+ convolution_kernel_size: int = 5
+ feed_forward_expansion_factor: int = 4
+ freq_mask_count: int = 2
+ freq_mask_max_bins: int = 27
+ time_mask_count: int = 10
+ time_mask_max_frames: int = 40
+ time_mask_max_ratio: float = 0.05
+ time_masks_per_frame: float = 0.0
+ use_dynamic_time_mask_max_frames: bool = True
+ # If None, defaults to 0.1.
+ input_dropout_rate: Optional[float] = 0.1
+ batch_norm_momentum: float = 0.999
+ batch_norm_epsilon: float = 0.001
+ use_specaug: bool = True
+ attention_temperature: float = 1.0
+ activation_function_name: str = 'swish'
+ use_post_layer_norm: bool = True
+
+
+class LayerNorm(nn.Module):
+ """Module implementing layer normalization.
+
+  This implementation is the same as the one described in this paper:
+  https://arxiv.org/pdf/1607.06450.pdf.
+
+  Note: we multiply normalized inputs by (1 + scale) and initialize scale to
+  zeros; this differs from the default Flax implementation, which multiplies by
+  scale and initializes it to ones.
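+
+  A minimal usage sketch (the shape is arbitrary; with zero-initialized scale
+  and bias this reduces to plain normalization)::
+
+    >>> layer = LayerNorm(dim=4)
+    >>> x = jnp.ones((2, 3, 4))
+    >>> variables = layer.init(jax.random.PRNGKey(0), x)
+    >>> y = layer.apply(variables, x)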
+ """
+ dim: int = 0
+ epsilon: float = 1e-6
+
+ def setup(self):
+ self.scale = self.param('scale', nn.initializers.zeros, [self.dim])
+ self.bias = self.param('bias', nn.initializers.zeros, [self.dim])
+
+ @nn.compact
+ def __call__(self, inputs):
+ mean = jnp.mean(inputs, axis=[-1], keepdims=True)
+ var = jnp.mean(jnp.square(inputs - mean), axis=[-1], keepdims=True)
+
+ normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon)
+ normed_inputs *= (1 + self.scale)
+ normed_inputs += self.bias
+
+ return normed_inputs
+
+
+class Subsample(nn.Module):
+ """Module to perform strided convolution in order to subsample inputs.
+
+ Attributes:
+ encoder_dim: model dimension of conformer.
+ input_dropout_rate: dropout rate for inputs.
+ """
+ encoder_dim: int = 0
+ input_dropout_rate: float = 0.0
+
+ @nn.compact
+ def __call__(self, inputs, input_paddings, train):
+ output_paddings = input_paddings
+ outputs = jnp.expand_dims(inputs, axis=-1)
+
+ outputs, output_paddings = Conv2dSubsampling(
+ input_channels=1, output_channels=self.encoder_dim)(
+ outputs, output_paddings)
+
+ outputs, output_paddings = Conv2dSubsampling(
+ input_channels=self.encoder_dim,
+ output_channels=self.encoder_dim)(outputs, output_paddings)
+
+ batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape
+
+ outputs = jnp.reshape(
+ outputs, (batch_size, subsampled_lengths, subsampled_dims * channels))
+
+ outputs = nn.Dense(
+ self.encoder_dim,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform())(
+ outputs)
+
+ outputs = outputs + AddPositionalEmbedding(embedding_dim=self.encoder_dim)(
+ seq_length=outputs.shape[1])
+
+ outputs = nn.Dropout(
+ rate=self.input_dropout_rate, deterministic=not train)(
+ outputs)
+
+ return outputs, output_paddings
+
+
+class Conv2dSubsampling(nn.Module):
+ """Helper module used in Subsample layer.
+
+ 1) Performs strided convolution over inputs and then applies non-linearity.
+ 2) Also performs strided convolution over input_paddings to return the correct
+ paddings for downstream layers.
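+
+  A minimal usage sketch (shapes are arbitrary; each block halves the time
+  axis, so the two stacked blocks in Subsample give the overall 4x
+  subsampling)::
+
+    >>> block = Conv2dSubsampling(input_channels=1, output_channels=4)
+    >>> x = jnp.ones((1, 8, 10, 1))
+    >>> pads = jnp.zeros((1, 8))
+    >>> params = block.init(jax.random.PRNGKey(0), x, pads)
+    >>> y, out_pads = block.apply(params, x, pads)  # (1, 4, 5, 4), (1, 4)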
+ """
+ input_channels: int = 0
+ output_channels: int = 0
+ filter_stride: List[int] = (2, 2)
+ padding: str = 'SAME'
+
+ def setup(self):
+ self.filter_shape = (3, 3, self.input_channels, self.output_channels)
+ self.kernel = self.param('kernel',
+ nn.initializers.xavier_uniform(),
+ self.filter_shape)
+ self.bias = self.param(
+ 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels)
+
+ @nn.compact
+ def __call__(self, inputs, paddings):
+ # Computing strided convolution to subsample inputs.
+ feature_group_count = inputs.shape[3] // self.filter_shape[2]
+ outputs = jax.lax.conv_general_dilated(
+ lhs=inputs,
+ rhs=self.kernel,
+ window_strides=self.filter_stride,
+ padding=self.padding,
+ rhs_dilation=(1, 1),
+ dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
+ feature_group_count=feature_group_count)
+
+ outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,))
+ outputs = nn.relu(outputs)
+
+ # Computing correct paddings post input convolution.
+ input_length = paddings.shape[1]
+ stride = self.filter_stride[0]
+
+ pad_len = (input_length + stride - 1) // stride * stride - input_length
+ out_padding = jax.lax.conv_general_dilated(
+ lhs=paddings[:, :, None],
+ rhs=jnp.ones([1, 1, 1]),
+ window_strides=self.filter_stride[:1],
+ padding=[(0, pad_len)],
+ dimension_numbers=('NHC', 'HIO', 'NHC'))
+ out_padding = jnp.squeeze(out_padding, axis=-1)
+
+ # Mask outputs by correct paddings to ensure padded elements in inputs map
+ # to padded value in outputs.
+ outputs = outputs * \
+ (1.0 - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1))
+ return outputs, out_padding
+
+
+class FeedForwardModule(nn.Module):
+ """Feedforward block of conformer layer.
+ """
+ config: ConformerConfig
+
+ @nn.compact
+ def __call__(self, inputs, padding_mask=None, train=False):
+ config = self.config
+
+ inputs = LayerNorm(dim=config.encoder_dim)(inputs)
+
+ inputs = nn.Dense(
+ config.encoder_dim * config.feed_forward_expansion_factor,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform())(
+ inputs)
+ if config.activation_function_name == 'swish':
+ activation_fn = nn.swish
+ elif config.activation_function_name == 'gelu':
+ activation_fn = nn.gelu
+ else:
+      raise ValueError('Only "swish" and "gelu" are supported '
+                       'config.activation_function_name values, received '
+                       f'{config.activation_function_name}')
+ inputs = activation_fn(inputs)
+ inputs = nn.Dropout(rate=config.feed_forward_dropout_rate)(
+ inputs, deterministic=not train)
+
+ inputs = inputs * padding_mask
+
+ inputs = nn.Dense(
+ config.encoder_dim,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform())(
+ inputs)
+ inputs = inputs * padding_mask
+
+ if config.feed_forward_residual_dropout_rate is None:
+ feed_forward_residual_dropout_rate = 0.1
+ else:
+ feed_forward_residual_dropout_rate = (
+ config.feed_forward_residual_dropout_rate)
+ inputs = nn.Dropout(rate=feed_forward_residual_dropout_rate)(
+ inputs, deterministic=not train)
+
+ return inputs
+
+
+class AddPositionalEmbedding(nn.Module):
+ """Adds (optionally learned) positional embeddings to the inputs.
+
+ Attributes:
+ max_len: maximum possible length for the input
+ posemb_init: positional embedding initializer
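+
+  A minimal usage sketch (sequence length and embedding_dim are arbitrary)::
+
+    >>> posemb = AddPositionalEmbedding(embedding_dim=8)
+    >>> variables = posemb.init(jax.random.PRNGKey(0), seq_length=5)
+    >>> signal = posemb.apply(variables, seq_length=5)  # shape (1, 5, 8)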
+ """
+ min_timescale: int = 1
+ max_timescale: int = 10_000
+ embedding_dim: int = 512
+
+ @nn.compact
+ def __call__(self, seq_length):
+ position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :]
+ num_timescales = self.embedding_dim // 2
+ log_timescale_increment = (
+ math.log(float(self.max_timescale) / float(self.min_timescale)) /
+ jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1))
+ inv_timescales = self.min_timescale * jnp.exp(
+ jnp.arange(num_timescales, dtype=jnp.float32) *
+ -log_timescale_increment)
+ scaled_time = (
+ position[:, :, jnp.newaxis] *
+ inv_timescales[jnp.newaxis, jnp.newaxis, :])
+ signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)],
+ axis=2).astype(jnp.float32)
+ # Force usage of `np` rather than `jnp` to compute static values at trace
+ # time.
+ signal = jnp.pad(signal,
+ [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]])
+ return signal
+
+
+# Adapted from lingvo attention layer for query scaling
+# https://github.com/tensorflow/lingvo/blob/7de4ca8fff3cb28c2ecb21bbd7b02a964ce727f7/lingvo/jax/layers/attentions.py#L201
+class QueryScaler(nn.Module):
+ """A layer to scale individual dims of the query attention matrix."""
+ dim: int = 0
+
+ def setup(self):
+ self.scale = self.param('scale', nn.initializers.zeros, [self.dim])
+
+ @nn.compact
+ def __call__(self, inputs):
+ inputs_shape = inputs.shape
+ if inputs_shape[-1] != self.dim:
+ raise ValueError('QueryScaler expects inputs to have'
+ ' same last dimension as scaling param.')
+
+ # 1.0/jax.nn.softplus(0.0) = 1.442695041. Hard code this number so that we
+ # can avoid unnecessary XLA op fusion mess on TPU.
+ r_softplus_0 = 1.442695041
+
+ scale = jnp.array(r_softplus_0, dtype=inputs.dtype)
+ scale *= jax.nn.softplus(self.scale)
+
+ return inputs * scale
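+
+# A minimal QueryScaler usage sketch (shapes are arbitrary). At initialization
+# the learned scale is zero, so 1.442695041 * softplus(0) == 1.442695041 *
+# log(2) ~= 1.0 and the layer is approximately the identity; training then
+# rescales each query dimension individually.
+#
+#   scaler = QueryScaler(dim=4)
+#   q = jnp.ones((1, 2, 4))
+#   params = scaler.init(jax.random.PRNGKey(0), q)
+#   q_scaled = scaler.apply(params, q)  # ~= q at init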
+
+
+# Modifying flax linen default dot product attention function to add
+# query scaling, reference to original function here :
+# https://github.com/google/flax/blob/a9af38085a7a49b571cf37d375060fd683e74972/flax/linen/attention.py#L121
+def dot_product_attention(query,
+ key,
+ value,
+ bias=None,
+ mask=None,
+ broadcast_dropout=True,
+ dropout_rng=None,
+ dropout_rate=0.,
+ deterministic=False,
+ dtype=jnp.float32,
+ precision=None,
+ temperature=1.0):
+ """Computes dot-product attention given query, key, and value.
+
+ This is the core function for applying attention based on
+ https://arxiv.org/abs/1706.03762. It's slightly modified to add query scaling.
+ It calculates the attention weights given query and key and combines the
+ values using the attention weights.
+
+ Note: query, key, value needn't have any batch dimensions.
+
+ Args:
+ query: queries for calculating attention with shape of
+ `[batch..., q_length, num_heads, qk_depth_per_head]`.
+ key: keys for calculating attention with shape of
+ `[batch..., kv_length, num_heads, qk_depth_per_head]`.
+ value: values to be used in attention with shape of
+ `[batch..., kv_length, num_heads, v_depth_per_head]`.
+ bias: bias for the attention weights. This should be broadcastable to the
+ shape `[batch..., num_heads, q_length, kv_length]`.
+ This can be used for incorporating causal masks, padding masks,
+ proximity bias, etc.
+ mask: mask for the attention weights. This should be broadcastable to the
+ shape `[batch..., num_heads, q_length, kv_length]`.
+ This can be used for incorporating causal masks.
+ Attention weights are masked out if their corresponding mask value
+ is `False`.
+ broadcast_dropout: bool: use a broadcasted dropout along batch dims.
+ dropout_rng: JAX PRNGKey: to be used for dropout
+ dropout_rate: dropout rate
+ deterministic: bool, deterministic or not (to apply dropout)
+ dtype: the dtype of the computation (default: float32)
+ precision: numerical precision of the computation see `jax.lax.Precision`
+ for details.
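+    temperature: a scalar multiplier applied to the weighted attention output
+      (default 1.0).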
+
+ Returns:
+ Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
+ """
+ assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
+ assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], (
+ 'q, k, v batch dims must match.')
+ assert query.shape[-2] == key.shape[-2] == value.shape[-2], (
+ 'q, k, v num_heads must match.')
+ assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
+
+ # compute attention weights
+ query = QueryScaler(dim=query.shape[-1])(query)
+ attn_weights = nn.attention.dot_product_attention_weights(
+ query,
+ key,
+ bias,
+ mask,
+ broadcast_dropout,
+ dropout_rng,
+ dropout_rate,
+ deterministic,
+ dtype,
+ precision)
+
+ # return weighted sum over values for each query position
+ return jnp.einsum(
+ '...hqk,...khd->...qhd', attn_weights, value,
+ precision=precision) * temperature
+
+
+class MultiHeadedSelfAttention(nn.Module):
+ """Self attention sub-layer used in the Conformer layer.
+
+ Input is first normalized using layer norm. Output is processed using
+ multi-headed attention.
+
+ Note: this attention implementation uses a learned scale parameter to scale
+ query matrix before passing it to flax attention module.
+ """
+ config: ConformerConfig = None
+
+ @nn.compact
+ def __call__(self, inputs, paddings, train):
+ config = self.config
+ mask_paddings = 1 - paddings
+ attention_mask = nn.make_attention_mask(
+ mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32)
+
+ inputs = LayerNorm(dim=config.encoder_dim)(inputs)
+ attention_fn = functools.partial(
+ dot_product_attention, temperature=config.attention_temperature)
+ result = nn.MultiHeadDotProductAttention(
+ num_heads=config.num_attention_heads,
+ qkv_features=config.encoder_dim,
+ decode=False,
+ dtype=config.dtype,
+ kernel_init=nn.initializers.xavier_uniform(),
+ bias_init=nn.initializers.zeros,
+ use_bias=True,
+ broadcast_dropout=False,
+ attention_fn=attention_fn,
+ dropout_rate=config.attention_dropout_rate,
+ deterministic=not train)(
+ inputs_q=inputs, mask=attention_mask)
+
+ if config.attention_residual_dropout_rate is None:
+ attention_residual_dropout_rate = 0.1
+ else:
+ attention_residual_dropout_rate = config.attention_residual_dropout_rate
+ result = nn.Dropout(
+ rate=attention_residual_dropout_rate, deterministic=not train)(
+ result)
+
+ return result
+
+
+class BatchNorm(nn.Module):
+ """Implements batch norm respecting input paddings.
+
+ This implementation takes into account input padding by masking inputs before
+ computing mean and variance.
+
+ This is inspired by lingvo jax implementation of BatchNorm:
+ https://github.com/tensorflow/lingvo/blob/84b85514d7ad3652bc9720cb45acfab08604519b/lingvo/jax/layers/normalizations.py#L92
+
+ and the corresponding defaults for momentum and epsilon have been copied over
+ from lingvo.
+ """
+ config: ConformerConfig
+
+ def setup(self):
+ dim = self.config.encoder_dim
+ dtype = self.config.dtype
+
+ self.ra_mean = self.variable('batch_stats',
+ 'mean',
+ lambda s: jnp.zeros(s, dtype),
+ dim)
+ self.ra_var = self.variable('batch_stats',
+ 'var',
+ lambda s: jnp.ones(s, dtype),
+ dim)
+
+ self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype)
+ self.beta = self.param('bias', nn.initializers.zeros, dim, dtype)
+
+ @nn.compact
+ def __call__(self,
+ inputs,
+ input_paddings,
+ update_batch_norm,
+ use_running_average_bn):
+ rank = inputs.ndim
+ reduce_over_dims = list(range(0, rank - 1))
+
+ padding = jnp.expand_dims(input_paddings, -1)
+ momentum = self.config.batch_norm_momentum
+ epsilon = self.config.batch_norm_epsilon
+
+ if use_running_average_bn:
+ mean = self.ra_mean.value
+ var = self.ra_var.value
+
+ else:
+ # compute batch statistics
+ mask = 1.0 - padding
+ sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True)
+ count_v = jnp.sum(
+ jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True)
+
+ count_v = jnp.maximum(count_v, 1.0)
+ mean = sum_v / count_v
+
+ sum_vv = jnp.sum(
+ (inputs - mean) * (inputs - mean) * mask,
+ axis=reduce_over_dims,
+ keepdims=True)
+
+ var = sum_vv / count_v
+
+ if update_batch_norm:
+ self.ra_mean.value = momentum * \
+ self.ra_mean.value + (1 - momentum) * mean
+ self.ra_var.value = momentum * \
+ self.ra_var.value + (1 - momentum) * var
+
+ inv = (1 + self.gamma) / jnp.sqrt(var + epsilon)
+ bn_output = (inputs - mean) * inv + self.beta
+ bn_output *= 1.0 - padding
+
+ return bn_output
+
+
+class ConvolutionBlock(nn.Module):
+ r"""Convolution block in conformer layer.
+
+ architecture:
+
+ input # (batch, time, hidden_dim)
+ |
+ layer_norm(.) # (batch, time, hidden_dim)
+ dense(.), dense(.) # (batch, time, 2 * hidden_dim)
+ | /
+ glu(.) # (batch, time, hidden_dim)
+ depthwise_conv1d(.)
+ batch_norm(.)
+ act(.)
+ |
+ dense(.)
+ dropout(.)
+ |
+ output
+ """
+ config: ConformerConfig
+
+ @nn.compact
+ def __call__(self,
+ inputs,
+ input_paddings,
+ train,
+ update_batch_norm,
+ use_running_average_bn):
+ config = self.config
+ inputs = LayerNorm(dim=config.encoder_dim)(inputs)
+
+ input_gated1 = nn.Dense(
+ config.encoder_dim,
+ kernel_init=nn.initializers.xavier_uniform(),
+ use_bias=True)(
+ inputs)
+
+ input_gated2 = nn.Dense(
+ config.encoder_dim,
+ kernel_init=nn.initializers.xavier_uniform(),
+ use_bias=True)(
+ inputs)
+
+ inputs = input_gated1 * jax.nn.sigmoid(input_gated2)
+ inputs = inputs * (1 - jnp.expand_dims(input_paddings, -1))
+
+ inputs = nn.Conv(
+ features=config.encoder_dim,
+ kernel_size=(config.convolution_kernel_size,),
+ strides=(1,),
+ padding='SAME',
+ feature_group_count=config.encoder_dim,
+ use_bias=False,
+ kernel_init=nn.initializers.xavier_uniform())(
+ inputs)
+
+ inputs = BatchNorm(config)(inputs,
+ input_paddings,
+ update_batch_norm,
+ use_running_average_bn)
+ if config.activation_function_name == 'swish':
+ activation_fn = nn.swish
+ elif config.activation_function_name == 'gelu':
+ activation_fn = nn.gelu
+ else:
+      raise ValueError('Only "swish" and "gelu" are supported '
+                       'config.activation_function_name values, received '
+                       f'{config.activation_function_name}')
+ inputs = activation_fn(inputs)
+ inputs = nn.Dense(
+ config.encoder_dim, kernel_init=nn.initializers.xavier_uniform())(
+ inputs)
+
+ if config.conv_residual_dropout_rate is None:
+ conv_residual_dropout_rate = 0.0
+ else:
+ conv_residual_dropout_rate = config.conv_residual_dropout_rate
+ inputs = nn.Dropout(
+ rate=conv_residual_dropout_rate, deterministic=not train)(
+ inputs)
+ return inputs
+
+
+class ConformerBlock(nn.Module):
+ """Implements a single conformer encoder layer.
+
+ High level overview:
+
+ x = x + 0.5 * FeedForward(x)
+ x = x + MHSA(x)
+ x = x + ConvolutionBlock(x)
+ x = x + 0.5 * FeedForward(x)
+
+ y = layer_norm(x)
+
+ """
+ config: ConformerConfig
+
+ @nn.compact
+ def __call__(self,
+ inputs,
+ input_paddings,
+ train,
+ update_batch_norm,
+ use_running_average):
+ config = self.config
+ padding_mask = jnp.expand_dims(1 - input_paddings, -1)
+
+ inputs = inputs + 0.5 * FeedForwardModule(config=self.config)(
+ inputs, padding_mask, train)
+
+ inputs = inputs + MultiHeadedSelfAttention(config=self.config)(
+ inputs, input_paddings, train)
+
+ inputs = inputs + \
+ ConvolutionBlock(config)(inputs,
+ input_paddings,
+ train,
+ update_batch_norm,
+ use_running_average
+ )
+
+ inputs = inputs + 0.5 * FeedForwardModule(config=self.config)(
+ inputs, padding_mask, train)
+
+ if config.use_post_layer_norm:
+ inputs = LayerNorm(dim=config.encoder_dim)(inputs)
+
+ return inputs
+
+
+class Conformer(nn.Module):
+ """Conformer (encoder + decoder) block.
+
+ Takes audio input signals and outputs probability distribution over vocab size
+ for each time step. The output is then fed into a CTC loss which eliminates
+ the need for alignment with targets.
+ """
+ config: ConformerConfig
+
+ def setup(self):
+ self.specaug = spectrum_augmenter.SpecAug(
+ freq_mask_count=self.config.freq_mask_count,
+ freq_mask_max_bins=self.config.freq_mask_max_bins,
+ time_mask_count=self.config.time_mask_count,
+ time_mask_max_frames=self.config.time_mask_max_frames,
+ time_mask_max_ratio=self.config.time_mask_max_ratio,
+ time_masks_per_frame=self.config.time_masks_per_frame,
+ use_dynamic_time_mask_max_frames=self.config
+ .use_dynamic_time_mask_max_frames)
+
+ @nn.compact
+ def __call__(self,
+ inputs,
+ input_paddings,
+ train,
+ update_batch_norm: Optional[bool] = None,
+ use_running_average_bn: Optional[bool] = None):
+ config = self.config
+
+ outputs = inputs
+ output_paddings = input_paddings
+
+ # Set BN args if not supplied for backwards compatibility
+ if update_batch_norm is None:
+ update_batch_norm = train
+ if use_running_average_bn is None:
+ use_running_average_bn = not train
+
+ # Compute normalized log mel spectrograms from input audio signal.
+ preprocessing_config = preprocessor.LibrispeechPreprocessingConfig()
+ outputs, output_paddings = preprocessor.MelFilterbankFrontend(
+ preprocessing_config,
+ per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
+ per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)(
+ outputs, output_paddings)
+
+ # Ablate random parts of input along temporal and frequency dimension
+ # following the specaug procedure in https://arxiv.org/abs/1904.08779.
+ if train and config.use_specaug:
+ outputs, output_paddings = self.specaug(outputs, output_paddings)
+
+ # Subsample input by a factor of 4 by performing strided convolutions.
+ if config.input_dropout_rate is None:
+ input_dropout_rate = 0.1
+ else:
+ input_dropout_rate = config.input_dropout_rate
+ outputs, output_paddings = Subsample(
+ encoder_dim=config.encoder_dim,
+ input_dropout_rate=input_dropout_rate)(
+ outputs, output_paddings, train)
+
+ # Run the conformer encoder layers.
+ for _ in range(config.num_encoder_layers):
+ outputs = ConformerBlock(config)(outputs,
+ output_paddings,
+ train,
+ update_batch_norm,
+ use_running_average_bn)
+
+ outputs = LayerNorm(config.encoder_dim)(outputs)
+ # Run the decoder which in this case is a trivial projection layer.
+ outputs = nn.Dense(
+ config.vocab_size,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform())(
+ outputs)
+
+ return outputs, output_paddings
\ No newline at end of file
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models_ref.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models_ref.py
new file mode 100644
index 000000000..7b7c9720a
--- /dev/null
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models_ref.py
@@ -0,0 +1,525 @@
+r"""Deepspeech.
+
+This model uses a deepspeech2 network to convert speech to text.
+paper : https://arxiv.org/abs/1512.02595
+
+# BiLSTM code contributed by bastings@
+# github : https://github.com/bastings
+# webpage : https://bastings.github.io/
+"""
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from flax import linen as nn
+from flax import struct
+import jax
+from jax.experimental import rnn
+import jax.numpy as jnp
+
+from algoperf.workloads.librispeech_conformer.librispeech_jax import \
+ librispeech_preprocessor as preprocessor
+from algoperf.workloads.librispeech_conformer.librispeech_jax import \
+ spectrum_augmenter
+
+Array = jnp.ndarray
+StateType = Union[Array, Tuple[Array, ...]]
+PRNGKey = Any
+Shape = Tuple[int]
+Dtype = Any
+Carry = Any
+CarryHistory = Any
+Output = Any
+
+
+@struct.dataclass
+class DeepspeechConfig:
+ """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
+ vocab_size: int = 1024
+ dtype: Any = jnp.float32
+ encoder_dim: int = 512
+ num_lstm_layers: int = 6
+ num_ffn_layers: int = 3
+ conv_subsampling_factor: int = 2
+ conv_subsampling_layers: int = 2
+ use_specaug: bool = True
+ freq_mask_count: int = 2
+ freq_mask_max_bins: int = 27
+ time_mask_count: int = 10
+ time_mask_max_frames: int = 40
+ time_mask_max_ratio: float = 0.05
+ time_masks_per_frame: float = 0.0
+ use_dynamic_time_mask_max_frames: bool = True
+ batch_norm_momentum: float = 0.999
+ batch_norm_epsilon: float = 0.001
+ # If None, defaults to 0.1.
+ input_dropout_rate: Optional[float] = 0.1
+ # If None, defaults to 0.1.
+ feed_forward_dropout_rate: Optional[float] = 0.1
+ enable_residual_connections: bool = True
+ enable_decoder_layer_norm: bool = True
+ bidirectional: bool = True
+ use_tanh: bool = False
+ layernorm_everywhere: bool = False
+
+
+class Subsample(nn.Module):
+ """Module to perform strided convolution in order to subsample inputs.
+
+ Attributes:
+    config: DeepspeechConfig whose ``encoder_dim`` and ``input_dropout_rate``
+      fields set the output dimension and the input dropout rate.
+ """
+ config: DeepspeechConfig
+
+ @nn.compact
+ def __call__(self, inputs, output_paddings, train):
+ config = self.config
+ outputs = jnp.expand_dims(inputs, axis=-1)
+
+ outputs, output_paddings = Conv2dSubsampling(
+ encoder_dim=config.encoder_dim,
+ dtype=config.dtype,
+ batch_norm_momentum=config.batch_norm_momentum,
+ batch_norm_epsilon=config.batch_norm_epsilon,
+ input_channels=1,
+ output_channels=config.encoder_dim,
+ use_tanh=config.use_tanh
+ )(outputs, output_paddings, train)
+
+ outputs, output_paddings = Conv2dSubsampling(
+ encoder_dim=config.encoder_dim,
+ dtype=config.dtype,
+ batch_norm_momentum=config.batch_norm_momentum,
+ batch_norm_epsilon=config.batch_norm_epsilon,
+ input_channels=config.encoder_dim,
+ output_channels=config.encoder_dim,
+ use_tanh=config.use_tanh)(outputs, output_paddings, train)
+
+ batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape
+
+ outputs = jnp.reshape(
+ outputs, (batch_size, subsampled_lengths, subsampled_dims * channels))
+
+ outputs = nn.Dense(
+ config.encoder_dim,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform())(
+ outputs)
+
+ if config.input_dropout_rate is None:
+ input_dropout_rate = 0.1
+ else:
+ input_dropout_rate = config.input_dropout_rate
+ outputs = nn.Dropout(
+ rate=input_dropout_rate, deterministic=not train)(
+ outputs)
+
+ return outputs, output_paddings
+
+
+class Conv2dSubsampling(nn.Module):
+ """Helper module used in Subsample layer.
+
+ 1) Performs strided convolution over inputs and then applies non-linearity.
+ 2) Also performs strided convolution over input_paddings to return the correct
+ paddings for downstream layers.
+ """
+ input_channels: int = 0
+ output_channels: int = 0
+ filter_stride: List[int] = (2, 2)
+ padding: str = 'SAME'
+ encoder_dim: int = 0
+ dtype: Any = jnp.float32
+ batch_norm_momentum: float = 0.999
+ batch_norm_epsilon: float = 0.001
+ use_tanh: bool = False
+
+ def setup(self):
+ self.filter_shape = (3, 3, self.input_channels, self.output_channels)
+ self.kernel = self.param('kernel',
+ nn.initializers.xavier_uniform(),
+ self.filter_shape)
+ self.bias = self.param(
+ 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels)
+
+ @nn.compact
+ def __call__(self, inputs, paddings, train):
+ # Computing strided convolution to subsample inputs.
+ feature_group_count = inputs.shape[3] // self.filter_shape[2]
+ outputs = jax.lax.conv_general_dilated(
+ lhs=inputs,
+ rhs=self.kernel,
+ window_strides=self.filter_stride,
+ padding=self.padding,
+ rhs_dilation=(1, 1),
+ dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
+ feature_group_count=feature_group_count)
+
+ outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,))
+
+ if self.use_tanh:
+ outputs = nn.tanh(outputs)
+ else:
+ outputs = nn.relu(outputs)
+
+ # Computing correct paddings post input convolution.
+ input_length = paddings.shape[1]
+ stride = self.filter_stride[0]
+
+ pad_len = (input_length + stride - 1) // stride * stride - input_length
+ out_padding = jax.lax.conv_general_dilated(
+ lhs=paddings[:, :, None],
+ rhs=jnp.ones([1, 1, 1]),
+ window_strides=self.filter_stride[:1],
+ padding=[(0, pad_len)],
+ dimension_numbers=('NHC', 'HIO', 'NHC'))
+ out_padding = jnp.squeeze(out_padding, axis=-1)
+
+ # Mask outputs by correct paddings to ensure padded elements in inputs map
+ # to padded value in outputs.
+ outputs = outputs * (1.0 -
+ jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1))
+
+ return outputs, out_padding
+
+
+class FeedForwardModule(nn.Module):
+ """Feedforward block of conformer layer."""
+ config: DeepspeechConfig
+
+ @nn.compact
+ def __call__(self, inputs, input_paddings=None, train=False):
+ padding_mask = jnp.expand_dims(1 - input_paddings, -1)
+ config = self.config
+
+ if config.layernorm_everywhere:
+ inputs = LayerNorm(config.encoder_dim)(inputs)
+ else:
+ inputs = BatchNorm(config.encoder_dim,
+ config.dtype,
+ config.batch_norm_momentum,
+ config.batch_norm_epsilon)(inputs,
+ input_paddings,
+ train)
+ inputs = nn.Dense(
+ config.encoder_dim,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform())(
+ inputs)
+ if config.use_tanh:
+ inputs = nn.tanh(inputs)
+ else:
+ inputs = nn.relu(inputs)
+ inputs *= padding_mask
+
+ if config.feed_forward_dropout_rate is None:
+ feed_forward_dropout_rate = 0.1
+ else:
+ feed_forward_dropout_rate = config.feed_forward_dropout_rate
+ inputs = nn.Dropout(rate=feed_forward_dropout_rate)(
+ inputs, deterministic=not train)
+
+ return inputs
+
+
+class LayerNorm(nn.Module):
+ """Module implementing layer normalization.
+
+  This implementation is the same as the one described in this paper:
+  https://arxiv.org/pdf/1607.06450.pdf.
+
+  Note: we multiply normalized inputs by (1 + scale) and initialize scale to
+  zeros; this differs from the default Flax implementation, which multiplies by
+  scale and initializes it to ones.
+ """
+ dim: int = 0
+ epsilon: float = 1e-6
+
+ def setup(self):
+ self.scale = self.param('scale', nn.initializers.zeros, [self.dim])
+ self.bias = self.param('bias', nn.initializers.zeros, [self.dim])
+
+ @nn.compact
+ def __call__(self, inputs):
+ mean = jnp.mean(inputs, axis=-1, keepdims=True)
+ var = jnp.mean(jnp.square(inputs - mean), axis=-1, keepdims=True)
+
+ normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon)
+ normed_inputs *= (1 + self.scale)
+ normed_inputs += self.bias
+
+ return normed_inputs
+
+
+class BatchNorm(nn.Module):
+ """Implements batch norm respecting input paddings.
+
+ This implementation takes into account input padding by masking inputs before
+ computing mean and variance.
+
+ This is inspired by lingvo jax implementation of BatchNorm:
+ https://github.com/tensorflow/lingvo/blob/84b85514d7ad3652bc9720cb45acfab08604519b/lingvo/jax/layers/normalizations.py#L92
+
+ and the corresponding defaults for momentum and epsilon have been copied over
+ from lingvo.
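+
+  A minimal eval-mode usage sketch (shapes are arbitrary). Note that the
+  training branch calls ``jax.lax.psum`` over an axis named ``'batch'``, so in
+  train mode the layer must run inside ``jax.pmap(..., axis_name='batch')``::
+
+    >>> bn = BatchNorm(encoder_dim=4)
+    >>> x = jnp.ones((2, 3, 4))
+    >>> pads = jnp.zeros((2, 3))
+    >>> variables = bn.init(jax.random.PRNGKey(0), x, pads, train=False)
+    >>> y = bn.apply(variables, x, pads, train=False)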
+ """
+ encoder_dim: int = 0
+ dtype: Any = jnp.float32
+ batch_norm_momentum: float = 0.999
+ batch_norm_epsilon: float = 0.001
+
+ def setup(self):
+ dim = self.encoder_dim
+ dtype = self.dtype
+
+ self.ra_mean = self.variable('batch_stats',
+ 'mean',
+ lambda s: jnp.zeros(s, dtype),
+ dim)
+ self.ra_var = self.variable('batch_stats',
+ 'var',
+ lambda s: jnp.ones(s, dtype),
+ dim)
+
+ self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype)
+ self.beta = self.param('bias', nn.initializers.zeros, dim, dtype)
+
+ def _get_default_paddings(self, inputs):
+ """Gets the default paddings for an input."""
+ in_shape = list(inputs.shape)
+ in_shape[-1] = 1
+
+ return jnp.zeros(in_shape, dtype=inputs.dtype)
+
+ @nn.compact
+ def __call__(self, inputs, input_paddings=None, train=False):
+ rank = inputs.ndim
+ reduce_over_dims = list(range(0, rank - 1))
+
+ if input_paddings is None:
+ padding = self._get_default_paddings(inputs)
+ else:
+ padding = jnp.expand_dims(input_paddings, -1)
+
+ momentum = self.batch_norm_momentum
+ epsilon = self.batch_norm_epsilon
+
+ if train:
+ mask = 1.0 - padding
+ sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True)
+ count_v = jnp.sum(
+ jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True)
+
+ sum_v = jax.lax.psum(sum_v, axis_name='batch')
+ count_v = jax.lax.psum(count_v, axis_name='batch')
+
+ count_v = jnp.maximum(count_v, 1.0)
+ mean = sum_v / count_v
+ variance = (inputs - mean) * (inputs - mean) * mask
+
+ sum_vv = jnp.sum(variance, axis=reduce_over_dims, keepdims=True)
+
+ sum_vv = jax.lax.psum(sum_vv, axis_name='batch')
+ var = sum_vv / count_v
+
+ self.ra_mean.value = momentum * self.ra_mean.value + (1 - momentum) * mean
+ self.ra_var.value = momentum * self.ra_var.value + (1 - momentum) * var
+ else:
+ mean = self.ra_mean.value
+ var = self.ra_var.value
+
+ inv = (1 + self.gamma) / jnp.sqrt(var + epsilon)
+
+ bn_output = (inputs - mean) * inv + self.beta
+ bn_output *= 1.0 - padding
+
+ return bn_output
+
+
+class CudnnLSTM(nn.Module):
+ features: int
+ num_layers: int = 1
+ dropout_rate: float = 0.0
+ bidirectional: bool = False
+
+ @nn.compact
+ def __call__(
+ self,
+ inputs: Array,
+ segmentation_mask: Optional[Array] = None,
+ return_carry: Optional[bool] = None,
+ deterministic: bool = False,
+ initial_states: Optional[Tuple[Array, Array]] = None,
+ use_cuda: bool = True,
+ ) -> Union[Array, Tuple[Array, Carry]]:
+
+ if jax.devices()[0].platform != 'gpu':
+ use_cuda = False
+
+ batch_size = inputs.shape[0]
+ input_size = inputs.shape[2]
+ num_directions = 2 if self.bidirectional else 1
+ dropout = 0.0 if deterministic else self.dropout_rate
+
+ weights = self.param(
+ 'weights',
+ rnn.init_lstm_weight,
+ input_size,
+ self.features,
+ self.num_layers,
+ self.bidirectional,
+ )
+
+ if initial_states is None:
+ h_0 = jnp.zeros(
+ (num_directions * self.num_layers, batch_size, self.features),
+ jnp.float32,
+ )
+ c_0 = jnp.zeros(
+ (num_directions * self.num_layers, batch_size, self.features),
+ jnp.float32,
+ )
+ else:
+ h_0, c_0 = initial_states
+
+ if segmentation_mask is not None:
+ seq_lengths = jnp.sum(1 - segmentation_mask, axis=1, dtype=jnp.int32)
+ else:
+ seq_lengths = jnp.full((batch_size,), inputs.shape[1], dtype=jnp.int32)
+
+ if use_cuda:
+ y, h, c = rnn.lstm(
+ x=inputs, h_0=h_0, c_0=c_0, weights=weights,
+ seq_lengths=seq_lengths, input_size=input_size,
+ hidden_size=self.features, num_layers=self.num_layers,
+ dropout=dropout, bidirectional=self.bidirectional,
+ )
+ else:
+ weight_ih, weight_hh, bias_ih, bias_hh = self.unpack_weights(
+ weights, input_size)
+ y, h, c = rnn.lstm_ref(
+ x=inputs, h_0=h_0, c_0=c_0, W_ih=weight_ih, W_hh=weight_hh,
+ b_ih=bias_ih, b_hh=bias_hh, seq_lengths=seq_lengths,
+ input_size=input_size, hidden_size=self.features,
+ num_layers=self.num_layers, dropout=dropout,
+ bidirectional=self.bidirectional,
+ )
+
+ if return_carry:
+ return y, (h, c)
+
+ return y
+
+ @nn.nowrap
+ def unpack_weights(
+ self, weights: Array, input_size: int
+ ) -> Tuple[
+ Dict[int, Array], Dict[int, Array], Dict[int, Array], Dict[int, Array]]:
+ return jax.experimental.rnn.unpack_lstm_weights(
+ weights,
+ input_size,
+ self.features,
+ self.num_layers,
+ self.bidirectional,
+ )
+
+
+class BatchRNN(nn.Module):
+ """Implements a single deepspeech encoder layer.
+ """
+ config: DeepspeechConfig
+
+ @nn.compact
+ def __call__(self, inputs, input_paddings, train):
+ config = self.config
+
+ if config.layernorm_everywhere:
+ inputs = LayerNorm(config.encoder_dim)(inputs)
+ else:
+ inputs = BatchNorm(config.encoder_dim,
+ config.dtype,
+ config.batch_norm_momentum,
+ config.batch_norm_epsilon)(inputs,
+ input_paddings,
+ train)
+ output = CudnnLSTM(
+ features=config.encoder_dim // 2,
+ bidirectional=config.bidirectional,
+ num_layers=1)(inputs, input_paddings)
+
+ return output
+
+
+class Deepspeech(nn.Module):
+ """Conformer (encoder + decoder) block.
+
+ Takes audio input signals and outputs probability distribution over vocab size
+ for each time step. The output is then fed into a CTC loss which eliminates
+ the need for alignment with targets.
+ """
+ config: DeepspeechConfig
+
+ def setup(self):
+ config = self.config
+ self.specaug = spectrum_augmenter.SpecAug(
+ freq_mask_count=config.freq_mask_count,
+ freq_mask_max_bins=config.freq_mask_max_bins,
+ time_mask_count=config.time_mask_count,
+ time_mask_max_frames=config.time_mask_max_frames,
+ time_mask_max_ratio=config.time_mask_max_ratio,
+ time_masks_per_frame=config.time_masks_per_frame,
+ use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames
+ )
+
+ @nn.compact
+ def __call__(self, inputs, input_paddings, train):
+ config = self.config
+
+ outputs = inputs
+ output_paddings = input_paddings
+
+ # Compute normalized log mel spectrograms from input audio signal.
+ preprocessing_config = preprocessor.LibrispeechPreprocessingConfig()
+ outputs, output_paddings = preprocessor.MelFilterbankFrontend(
+ preprocessing_config,
+ per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
+ per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)(outputs,
+ output_paddings)
+
+ # Ablate random parts of input along temporal and frequency dimension
+ # following the specaug procedure in https://arxiv.org/abs/1904.08779.
+ if config.use_specaug and train:
+ outputs, output_paddings = self.specaug(outputs, output_paddings)
+
+ # Subsample input by a factor of 4 by performing strided convolutions.
+ outputs, output_paddings = Subsample(
+ config=config)(outputs, output_paddings, train)
+
+ # Run the lstm layers.
+ for _ in range(config.num_lstm_layers):
+ if config.enable_residual_connections:
+ outputs = outputs + BatchRNN(config)(outputs, output_paddings, train)
+ else:
+ outputs = BatchRNN(config)(outputs, output_paddings, train)
+
+ for _ in range(config.num_ffn_layers):
+ if config.enable_residual_connections:
+ outputs = outputs + FeedForwardModule(config=self.config)(
+ outputs, output_paddings, train)
+ else:
+ outputs = FeedForwardModule(config=self.config)(outputs,
+ output_paddings,
+ train)
+
+ # Run the decoder which in this case is a trivial projection layer.
+ if config.enable_decoder_layer_norm:
+ outputs = LayerNorm(config.encoder_dim)(outputs)
+
+ outputs = nn.Dense(
+ config.vocab_size,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform())(
+ outputs)
+
+ return outputs, output_paddings
\ No newline at end of file
diff --git a/algoperf/workloads/ogbg/ogbg_jax/models_ref.py b/algoperf/workloads/ogbg/ogbg_jax/models_ref.py
new file mode 100644
index 000000000..f0a9e3dc1
--- /dev/null
+++ b/algoperf/workloads/ogbg/ogbg_jax/models_ref.py
@@ -0,0 +1,88 @@
+# Forked from the init2winit implementation here
+# https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py.
+from typing import Optional, Tuple
+
+from flax import linen as nn
+import jax.numpy as jnp
+import jraph
+
+
+def _make_embed(latent_dim, name):
+
+ def make_fn(inputs):
+ return nn.Dense(features=latent_dim, name=name)(inputs)
+
+ return make_fn
+
+
+def _make_mlp(hidden_dims, dropout, activation_fn):
+ """Creates a MLP with specified dimensions."""
+
+ @jraph.concatenated_args
+ def make_fn(inputs):
+ x = inputs
+ for dim in hidden_dims:
+ x = nn.Dense(features=dim)(x)
+ x = nn.LayerNorm()(x)
+ x = activation_fn(x)
+ x = dropout(x)
+ return x
+
+ return make_fn
+
+
+class GNN(nn.Module):
+ """Defines a graph network.
+ The model assumes the input data is a jraph.GraphsTuple without global
+ variables. The final prediction will be encoded in the globals.
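+
+  A minimal usage sketch (the toy graph below and its feature sizes are
+  arbitrary)::
+
+    >>> import jax
+    >>> graph = jraph.GraphsTuple(
+    ...     nodes=jnp.ones((3, 4)), edges=jnp.ones((2, 4)),
+    ...     senders=jnp.array([0, 1]), receivers=jnp.array([1, 2]),
+    ...     n_node=jnp.array([3]), n_edge=jnp.array([2]), globals=None)
+    >>> model = GNN(num_outputs=1)
+    >>> variables = model.init(jax.random.PRNGKey(0), graph, train=False)
+    >>> preds = model.apply(variables, graph, train=False)  # shape (1, 1)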
+ """
+ num_outputs: int
+ latent_dim: int = 256
+ hidden_dims: Tuple[int] = (256,)
+ # If None, defaults to 0.1.
+ dropout_rate: Optional[float] = 0.1
+ num_message_passing_steps: int = 5
+ activation_fn_name: str = 'relu'
+
+ @nn.compact
+ def __call__(self, graph, train):
+ if self.dropout_rate is None:
+ dropout_rate = 0.1
+ else:
+ dropout_rate = self.dropout_rate
+ dropout = nn.Dropout(rate=dropout_rate, deterministic=not train)
+
+ graph = graph._replace(
+ globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]))
+
+ embedder = jraph.GraphMapFeatures(
+ embed_node_fn=_make_embed(self.latent_dim, name='node_embedding'),
+ embed_edge_fn=_make_embed(self.latent_dim, name='edge_embedding'))
+ graph = embedder(graph)
+
+ if self.activation_fn_name == 'relu':
+ activation_fn = nn.relu
+ elif self.activation_fn_name == 'gelu':
+ activation_fn = nn.gelu
+ elif self.activation_fn_name == 'silu':
+ activation_fn = nn.silu
+ else:
+ raise ValueError(
+ f'Invalid activation function name: {self.activation_fn_name}')
+
+ for _ in range(self.num_message_passing_steps):
+ net = jraph.GraphNetwork(
+ update_edge_fn=_make_mlp(
+ self.hidden_dims, dropout=dropout, activation_fn=activation_fn),
+ update_node_fn=_make_mlp(
+ self.hidden_dims, dropout=dropout, activation_fn=activation_fn),
+ update_global_fn=_make_mlp(
+ self.hidden_dims, dropout=dropout, activation_fn=activation_fn))
+
+ graph = net(graph)
+
+ # Map globals to represent the final result
+ decoder = jraph.GraphMapFeatures(embed_global_fn=nn.Dense(self.num_outputs))
+ graph = decoder(graph)
+
+ return graph.globals
\ No newline at end of file
diff --git a/algoperf/workloads/wmt/wmt_jax/models_ref.py b/algoperf/workloads/wmt/wmt_jax/models_ref.py
new file mode 100644
index 000000000..e1f44aaa6
--- /dev/null
+++ b/algoperf/workloads/wmt/wmt_jax/models_ref.py
@@ -0,0 +1,604 @@
+"""Transformer-based machine translation model.
+
+Reference https://github.com/google/flax/tree/main/examples/wmt.
+"""
+
+from typing import Any, Callable, Optional
+
+from flax import linen as nn
+from flax import struct
+from jax import lax
+import jax.numpy as jnp
+import numpy as np
+
+
+@struct.dataclass
+class TransformerConfig:
+ """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
+ share_embeddings: bool = True
+ dtype: Any = jnp.float32
+ vocab_size: int = 32000
+ emb_dim: int = 1024
+ num_heads: int = 16
+ num_layers: int = 6
+ qkv_dim: int = 1024
+ mlp_dim: int = 1024
+ max_len: int = 256
+ activation: Callable = nn.relu
+ glu: bool = False
+  # If None, defaults to 0.1.
+  dropout_rate: Optional[float] = 0.1
+  # If None, defaults to 0.1.
+  attention_dropout_rate: Optional[float] = 0.1
+ attention_temp: float = 1.0
+ deterministic: bool = False
+ decode: bool = False
+ kernel_init: Callable = nn.initializers.xavier_uniform()
+ bias_init: Callable = nn.initializers.normal(stddev=1e-6)
+ posemb_init: Optional[Callable] = None
+ pre_ln: bool = True
+
+
+def shift_right(x, axis=1):
+ """Shift the input to the right by padding on axis 1."""
+ pad_widths = [(0, 0)] * len(x.shape)
+ pad_widths[axis] = (1, 0)
+ padded = jnp.pad(
+ x, pad_widths, mode='constant', constant_values=x.dtype.type(0))
+ return padded[:, :-1]
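+
+# A worked example of shift_right (values are arbitrary):
+#   shift_right(jnp.array([[5, 6, 7]])) == [[0, 5, 6]]
+# i.e. decoder targets are shifted so that position t only sees tokens strictly
+# before t during teacher forcing.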
+
+
+def sinusoidal_init(max_len=2048, min_scale=1.0, max_scale=10000.0):
+ """1D Sinusoidal Position Embedding Initializer.
+
+ Args:
+ max_len: maximum possible length for the input.
+ min_scale: float: minimum frequency-scale in sine grating.
+ max_scale: float: maximum frequency-scale in sine grating.
+
+ Returns:
+ output: init function returning `(1, max_len, d_feature)`
+ """
+
+ def init(key, shape, dtype=np.float32):
+ """Sinusoidal init."""
+ del key, dtype
+ d_feature = shape[-1]
+ pe = np.zeros((max_len, d_feature), dtype=np.float32)
+ position = np.arange(0, max_len)[:, np.newaxis]
+ scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1)
+ div_term = min_scale * np.exp(np.arange(0, d_feature // 2) * scale_factor)
+ pe[:, :d_feature // 2] = np.sin(position * div_term)
+ pe[:, d_feature // 2:2 * (d_feature // 2)] = np.cos(position * div_term)
+ pe = pe[np.newaxis, :, :] # [1, max_len, d_feature]
+ return jnp.array(pe)
+
+ return init
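+
+# A minimal sinusoidal_init usage sketch (sizes are arbitrary): the initializer
+# ignores the rng key and returns a fixed table of shape (1, max_len, d_feature).
+#
+#   pe = sinusoidal_init(max_len=8)(None, (1, 8, 16))  # pe.shape == (1, 8, 16)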
+
+
+class AddPositionEmbs(nn.Module):
+ """Adds (optionally learned) positional embeddings to the inputs.
+
+ Attributes:
+ config: TransformerConfig dataclass containing hyperparameters.
+ decode: whether to run in single-position autoregressive mode.
+ """
+ config: TransformerConfig
+ decode: bool = False
+
+ @nn.compact
+ def __call__(self, inputs, inputs_positions=None):
+ """Applies AddPositionEmbs module.
+
+ By default this layer uses a fixed sinusoidal embedding table. If a
+ learned position embedding is desired, pass an initializer to
+ posemb_init in the configuration.
+
+ Args:
+ inputs: input data.
+ inputs_positions: input position indices for packed sequences.
+
+ Returns:
+ output: `(bs, timesteps, in_dim)`
+ """
+ cfg = self.config
+ # inputs.shape is (batch_size, seq_len, emb_dim)
+ assert inputs.ndim == 3, ('Number of dimensions should be 3,'
+ f' but it is: {inputs.ndim}')
+ length = inputs.shape[1]
+ pos_emb_shape = (1, cfg.max_len, inputs.shape[-1])
+ if cfg.posemb_init is None:
+ # Use a fixed (non-learned) sinusoidal position embedding.
+ pos_embedding = sinusoidal_init(max_len=cfg.max_len)(None,
+ pos_emb_shape,
+ None)
+ else:
+ pos_embedding = self.param('pos_embedding',
+ cfg.posemb_init,
+ pos_emb_shape)
+ pe = pos_embedding[:, :length, :]
+
+ # We use a cache position index for tracking decoding position.
+ if self.decode:
+ is_initialized = self.has_variable('cache', 'cache_index')
+ cache_index = self.variable('cache',
+ 'cache_index',
+ lambda: jnp.array(0, dtype=jnp.uint32))
+ if is_initialized:
+ i = cache_index.value
+ cache_index.value = i + 1
+ _, _, df = pos_embedding.shape
+ pe = lax.dynamic_slice(pos_embedding, jnp.array((0, i, 0)), (1, 1, df))
+ if inputs_positions is None:
+ # normal unpacked case:
+ return inputs + pe
+ else:
+ # for packed data we need to use known position indices:
+ return inputs + jnp.take(pe[0], inputs_positions, axis=0)
+
+
+class MlpBlock(nn.Module):
+ """Transformer MLP / feed-forward block.
+
+ Attributes:
+ config: TransformerConfig dataclass containing hyperparameters.
+ out_dim: optionally specify out dimension.
+ """
+ config: TransformerConfig
+ out_dim: Optional[int] = None
+
+ @nn.compact
+ def __call__(self, inputs):
+ """Applies Transformer MlpBlock module."""
+ cfg = self.config
+ actual_out_dim = (
+ inputs.shape[-1] if self.out_dim is None else self.out_dim)
+ x = nn.Dense(
+ cfg.mlp_dim,
+ dtype=cfg.dtype,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init)(
+ inputs)
+ x = cfg.activation(x)
+ if cfg.glu:
+ y = nn.Dense(
+ cfg.mlp_dim,
+ dtype=cfg.dtype,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init)(
+ inputs)
+ x = x * y
+ if cfg.dropout_rate is None:
+ dropout_rate = 0.1
+ else:
+ dropout_rate = cfg.dropout_rate
+ x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
+ output = nn.Dense(
+ actual_out_dim,
+ dtype=cfg.dtype,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init)(
+ x)
+ output = nn.Dropout(rate=dropout_rate)(
+ output, deterministic=cfg.deterministic)
+ return output
+
+
+class Encoder1DBlock(nn.Module):
+ """Transformer encoder layer.
+
+ Attributes:
+ config: TransformerConfig dataclass containing hyperparameters.
+ """
+ config: TransformerConfig
+
+ @nn.compact
+ def __call__(self, inputs, encoder_mask=None):
+ """Applies Encoder1DBlock module.
+
+ Args:
+ inputs: input data.
+ encoder_mask: encoder self-attention mask.
+
+ Returns:
+ output after transformer encoder block.
+ """
+ cfg = self.config
+ pre_ln = cfg.pre_ln
+
+ # Attention block.
+ assert inputs.ndim == 3
+ x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs
+ if cfg.attention_dropout_rate is None:
+ attention_dropout_rate = 0.1
+ else:
+ attention_dropout_rate = cfg.attention_dropout_rate
+ x = nn.MultiHeadDotProductAttention(
+ num_heads=cfg.num_heads,
+ dtype=cfg.dtype,
+ qkv_features=cfg.qkv_dim,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ use_bias=False,
+ broadcast_dropout=False,
+ dropout_rate=attention_dropout_rate,
+ deterministic=cfg.deterministic)(
+ cfg.attention_temp * x, x, mask=encoder_mask)
+
+ if cfg.dropout_rate is None:
+ dropout_rate = 0.1
+ else:
+ dropout_rate = cfg.dropout_rate
+ x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
+ x = x + inputs
+ if not pre_ln:
+ x = nn.LayerNorm(dtype=cfg.dtype)(x)
+
+ # MLP block.
+ y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
+ y = MlpBlock(config=cfg)(y)
+
+ return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y)
+
+
+class EncoderDecoder1DBlock(nn.Module):
+ """Transformer encoder-decoder layer.
+
+ Attributes:
+ config: TransformerConfig dataclass containing hyperparameters.
+ """
+ config: TransformerConfig
+
+ @nn.compact
+ def __call__(self,
+ targets,
+ encoded,
+ decoder_mask=None,
+ encoder_decoder_mask=None):
+ """Applies EncoderDecoder1DBlock module.
+
+ Args:
+ targets: input data for decoder
+ encoded: input data from encoder
+ decoder_mask: decoder self-attention mask.
+ encoder_decoder_mask: encoder-decoder attention mask.
+
+ Returns:
+ output after transformer encoder-decoder block.
+ """
+ cfg = self.config
+ pre_ln = cfg.pre_ln
+
+ # Decoder block.
+ assert targets.ndim == 3
+ x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets
+
+ if cfg.attention_dropout_rate is None:
+ attention_dropout_rate = 0.1
+ else:
+ attention_dropout_rate = cfg.attention_dropout_rate
+ x = nn.MultiHeadDotProductAttention(
+ num_heads=cfg.num_heads,
+ dtype=cfg.dtype,
+ qkv_features=cfg.qkv_dim,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ use_bias=False,
+ broadcast_dropout=False,
+ dropout_rate=attention_dropout_rate,
+ deterministic=cfg.deterministic,
+ decode=cfg.decode)(
+ cfg.attention_temp * x, x, mask=decoder_mask)
+ if cfg.dropout_rate is None:
+ dropout_rate = 0.1
+ else:
+ dropout_rate = cfg.dropout_rate
+ x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
+ x = x + targets
+ if not pre_ln:
+ x = nn.LayerNorm(dtype=cfg.dtype)(x)
+
+ # Encoder-Decoder block.
+ y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
+ y = nn.MultiHeadDotProductAttention(
+ num_heads=cfg.num_heads,
+ dtype=cfg.dtype,
+ qkv_features=cfg.qkv_dim,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ use_bias=False,
+ broadcast_dropout=False,
+ dropout_rate=attention_dropout_rate,
+ deterministic=cfg.deterministic)(
+ cfg.attention_temp * y, encoded, mask=encoder_decoder_mask)
+
+ y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic)
+ y = y + x
+ if not pre_ln:
+ y = nn.LayerNorm(dtype=cfg.dtype)(y)
+
+ # MLP block.
+ z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y
+ z = MlpBlock(config=cfg)(z)
+
+ return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z)
+
+
+class Encoder(nn.Module):
+ """Transformer Model Encoder for sequence to sequence translation.
+
+ Attributes:
+ config: TransformerConfig dataclass containing hyperparameters.
+ shared_embedding: a shared embedding layer to use.
+ """
+ config: TransformerConfig
+ shared_embedding: Any = None
+
+ @nn.compact
+ def __call__(self, inputs, inputs_positions=None, encoder_mask=None):
+ """Applies Transformer model on the inputs.
+
+ Args:
+ inputs: input data
+ inputs_positions: input subsequence positions for packed examples.
+ encoder_mask: decoder self-attention mask.
+
+ Returns:
+ output of a transformer encoder.
+ """
+ cfg = self.config
+ assert inputs.ndim == 2 # (batch, len)
+
+ # Input Embedding
+ if self.shared_embedding is None:
+ input_embed = nn.Embed(
+ num_embeddings=cfg.vocab_size,
+ features=cfg.emb_dim,
+ embedding_init=nn.initializers.normal(stddev=1.0))
+ else:
+ input_embed = self.shared_embedding
+ x = inputs.astype('int32')
+ x = input_embed(x)
+ x = AddPositionEmbs(
+ config=cfg, decode=False, name='posembed_input')(
+ x, inputs_positions=inputs_positions)
+ if cfg.dropout_rate is None:
+ dropout_rate = 0.1
+ else:
+ dropout_rate = cfg.dropout_rate
+ x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
+
+ x = x.astype(cfg.dtype)
+
+ # Input Encoder
+ for lyr in range(cfg.num_layers):
+ x = Encoder1DBlock(
+ config=cfg, name=f'encoderblock_{lyr}')(x, encoder_mask)
+
+ encoded = (
+ nn.LayerNorm(dtype=cfg.dtype, name='encoder_layernorm')(x)
+ if cfg.pre_ln else x)
+
+ return encoded
+
+
+class Decoder(nn.Module):
+ """Transformer Model Decoder for sequence to sequence translation.
+
+ Attributes:
+ config: TransformerConfig dataclass containing hyperparameters.
+ shared_embedding: a shared embedding layer to use.
+ """
+ config: TransformerConfig
+ shared_embedding: Any = None
+
+ @nn.compact
+ def __call__(self,
+ encoded,
+ targets,
+ targets_positions=None,
+ decoder_mask=None,
+ encoder_decoder_mask=None):
+ """Applies Transformer model on the inputs.
+
+ Args:
+ encoded: encoded input data from encoder.
+ targets: target inputs.
+ targets_positions: input subsequence positions for packed examples.
+ decoder_mask: decoder self-attention mask.
+ encoder_decoder_mask: encoder-decoder attention mask.
+
+ Returns:
+ output of a transformer decoder.
+ """
+ cfg = self.config
+
+ assert encoded.ndim == 3 # (batch, len, depth)
+ assert targets.ndim == 2 # (batch, len)
+
+ # Target Embedding
+ if self.shared_embedding is None:
+ output_embed = nn.Embed(
+ num_embeddings=cfg.vocab_size,
+ features=cfg.emb_dim,
+ embedding_init=nn.initializers.normal(stddev=1.0))
+ else:
+ output_embed = self.shared_embedding
+
+ y = targets.astype('int32')
+ if not cfg.decode:
+ y = shift_right(y)
+ y = output_embed(y)
+ y = AddPositionEmbs(
+ config=cfg, decode=cfg.decode, name='posembed_output')(
+ y, inputs_positions=targets_positions)
+ if cfg.dropout_rate is None:
+ dropout_rate = 0.1
+ else:
+ dropout_rate = cfg.dropout_rate
+ y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic)
+
+ y = y.astype(cfg.dtype)
+
+ # Target-Input Decoder
+ for lyr in range(cfg.num_layers):
+ y = EncoderDecoder1DBlock(
+ config=cfg, name=f'encoderdecoderblock_{lyr}')(
+ y,
+ encoded,
+ decoder_mask=decoder_mask,
+ encoder_decoder_mask=encoder_decoder_mask)
+ y = (
+ nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_layernorm')(y)
+ if cfg.pre_ln else y)
+
+ # Use the transpose of embedding matrix for logit transform.
+ logits = output_embed.attend(y.astype(jnp.float32))
+ # Correctly normalize pre-softmax logits for this shared case.
+ logits = logits / jnp.sqrt(y.shape[-1])
+ return logits
+
+
+class Transformer(nn.Module):
+ """Transformer Model for sequence to sequence translation.
+
+ Attributes:
+ config: TransformerConfig dataclass containing hyperparameters.
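+
+  A minimal usage sketch (all sizes are arbitrary and kept small; with
+  ``deterministic=True`` no dropout rngs are needed)::
+
+    >>> import jax
+    >>> cfg = TransformerConfig(vocab_size=16, emb_dim=8, num_heads=2,
+    ...                         num_layers=1, qkv_dim=8, mlp_dim=8, max_len=8,
+    ...                         deterministic=True)
+    >>> model = Transformer(cfg)
+    >>> inputs = jnp.ones((1, 4), jnp.int32)
+    >>> targets = jnp.ones((1, 4), jnp.int32)
+    >>> variables = model.init(jax.random.PRNGKey(0), inputs, targets)
+    >>> logits = model.apply(variables, inputs, targets)  # (1, 4, 16)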
+ """
+ config: TransformerConfig
+
+ def setup(self):
+ cfg = self.config
+
+ if cfg.share_embeddings:
+ if cfg.vocab_size is not None:
+ assert cfg.vocab_size == cfg.vocab_size, (
+ "can't share embedding with different vocab sizes.")
+ self.shared_embedding = nn.Embed(
+ num_embeddings=cfg.vocab_size,
+ features=cfg.emb_dim,
+ embedding_init=nn.initializers.normal(stddev=1.0))
+ else:
+ self.shared_embedding = None
+
+ self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding)
+ self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding)
+
+ def encode(self, inputs, inputs_positions=None, inputs_segmentation=None):
+ """Applies Transformer encoder-branch on the inputs.
+
+ Args:
+ inputs: input data.
+ inputs_positions: input subsequence positions for packed examples.
+ inputs_segmentation: input segmentation info for packed examples.
+
+ Returns:
+ encoded feature array from the transformer encoder.
+ """
+ cfg = self.config
+ # Make padding attention mask.
+ encoder_mask = nn.make_attention_mask(
+ inputs > 0, inputs > 0, dtype=cfg.dtype)
+ # Add segmentation block-diagonal attention mask if using segmented data.
+ if inputs_segmentation is not None:
+ encoder_mask = nn.combine_masks(
+ encoder_mask,
+ nn.make_attention_mask(
+ inputs_segmentation,
+ inputs_segmentation,
+ jnp.equal,
+ dtype=cfg.dtype))
+ return self.encoder(
+ inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask)
+
+ def decode(
+ self,
+ encoded,
+ inputs, # only needed for masks
+ targets,
+ targets_positions=None,
+ inputs_segmentation=None,
+ targets_segmentation=None):
+ """Applies Transformer decoder-branch on encoded-input and target.
+
+ Args:
+ encoded: encoded input data from encoder.
+ inputs: input data (only needed for masking).
+ targets: target data.
+ targets_positions: target subsequence positions for packed examples.
+ inputs_segmentation: input segmentation info for packed examples.
+ targets_segmentation: target segmentation info for packed examples.
+
+ Returns:
+ logits array from transformer decoder.
+ """
+ cfg = self.config
+
+ # Make padding attention masks.
+ if cfg.decode:
+ decoder_mask = None
+ encoder_decoder_mask = nn.make_attention_mask(
+ jnp.ones_like(targets) > 0, inputs > 0, dtype=cfg.dtype)
+ else:
+ decoder_mask = nn.combine_masks(
+ nn.make_attention_mask(targets > 0, targets > 0, dtype=cfg.dtype),
+ nn.make_causal_mask(targets, dtype=cfg.dtype))
+ encoder_decoder_mask = nn.make_attention_mask(
+ targets > 0, inputs > 0, dtype=cfg.dtype)
+
+ # Add segmentation block-diagonal attention masks if using segmented data.
+ if inputs_segmentation is not None:
+ decoder_mask = nn.combine_masks(
+ decoder_mask,
+ nn.make_attention_mask(
+ targets_segmentation,
+ targets_segmentation,
+ jnp.equal,
+ dtype=cfg.dtype))
+ encoder_decoder_mask = nn.combine_masks(
+ encoder_decoder_mask,
+ nn.make_attention_mask(
+ targets_segmentation,
+ inputs_segmentation,
+ jnp.equal,
+ dtype=cfg.dtype))
+ logits = self.decoder(
+ encoded,
+ targets,
+ targets_positions=targets_positions,
+ decoder_mask=decoder_mask,
+ encoder_decoder_mask=encoder_decoder_mask)
+ return logits.astype(self.config.dtype)
+
+ def __call__(self,
+ inputs,
+ targets,
+ inputs_positions=None,
+ targets_positions=None,
+ inputs_segmentation=None,
+ targets_segmentation=None):
+ """Applies Transformer model on the inputs.
+
+ Args:
+ inputs: input data.
+ targets: target data.
+ inputs_positions: input subsequence positions for packed examples.
+ targets_positions: target subsequence positions for packed examples.
+ inputs_segmentation: input segmentation info for packed examples.
+ targets_segmentation: target segmentation info for packed examples.
+
+ Returns:
+ logits array from full transformer.
+ """
+ encoded = self.encode(
+ inputs,
+ inputs_positions=inputs_positions,
+ inputs_segmentation=inputs_segmentation)
+
+ return self.decode(
+ encoded,
+ inputs, # only used for masks
+ targets,
+ targets_positions=targets_positions,
+ inputs_segmentation=inputs_segmentation,
+ targets_segmentation=targets_segmentation)
\ No newline at end of file
diff --git a/tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py
new file mode 100644
index 000000000..10aeaa650
--- /dev/null
+++ b/tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py
@@ -0,0 +1,156 @@
+"""
+Runs a forward pass with random input for our DLRM models and compares outputs.
+Run it as:
+ python3 tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py
+"""
+
+import os
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import jax
+import jax.numpy as jnp
+
+from jax.tree_util import tree_structure, tree_leaves, tree_map
+
+
+def pytrees_are_equal(a, b, rtol=1e-5, atol=1e-8):
+ """
+ A custom function to check if two PyTrees are equal, handling floats with a tolerance.
+ """
+ # 1. Check if the structures are the same
+ if tree_structure(a) != tree_structure(b):
+ return False
+
+ # 2. Define a comparison function for leaves
+ def leaf_comparator(x, y):
+ # Use allclose for floating-point JAX arrays
+ if isinstance(x, jnp.ndarray) and jnp.issubdtype(x.dtype, jnp.floating):
+ return jnp.allclose(x, y, rtol=rtol, atol=atol)
+ # Use standard equality for everything else
+ else:
+ return x == y
+
+ # 3. Map the comparison function over the trees and check if all results are True
+ # We also need to flatten the results of the tree_map and check if all are True
+ comparison_tree = tree_map(leaf_comparator, a, b)
+ all_equal = all(tree_leaves(comparison_tree))
+
+ return all_equal
+
+from algoperf.workloads.criteo1tb.criteo1tb_jax.models_ref import \
+ DLRMResNet as OriginalDLRMResNet
+from algoperf.workloads.criteo1tb.criteo1tb_jax.models_ref import \
+ DlrmSmall as OriginalDlrmSmall
+from algoperf.workloads.criteo1tb.criteo1tb_jax.models import \
+ DLRMResNet as CustomDLRMResNet
+from algoperf.workloads.criteo1tb.criteo1tb_jax.models import \
+ DlrmSmall as CustomDlrmSmall
+
+BATCH, DENSE, SPARSE = 16, 13, 26
+FEATURES = DENSE + SPARSE
+VOCAB = 1000
+DEVICE = 'cuda'
+SEED = 1996
+
+
+class ModelEquivalenceTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ dict(
+ testcase_name='DLRMResNet, p=0.0',
+ model='dlrm_resnet',
+ dropout_rate=0.0),
+ dict(
+ testcase_name='DlrmSmall, p=0.0',
+ model='dlrm_small',
+ dropout_rate=0.0),
+ dict(
+ testcase_name='DLRMResNet, p=0.1',
+ model='dlrm_resnet',
+ dropout_rate=0.1),
+ dict(
+ testcase_name='DlrmSmall, p=0.1',
+ model='dlrm_small',
+ dropout_rate=0.1),
+ )
+ def test_forward(self, model, dropout_rate):
+ OrigCls, CustCls = (
+ (OriginalDLRMResNet, CustomDLRMResNet)
+ if model == 'dlrm_resnet'
+ else (OriginalDlrmSmall, CustomDlrmSmall)
+ )
+
+ # init model
+ rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
+ fake_batch = jnp.ones((2, 39))
+ orig_model = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate)
+ cust_model = CustCls(vocab_size=VOCAB)
+ initial_params_original = orig_model.init({'params': rng},
+ fake_batch,
+ train=False)
+ initial_params_custom = cust_model.init({'params': rng},
+ fake_batch,
+ train=False)
+ assert pytrees_are_equal(
+ initial_params_original, initial_params_custom, rtol=1e-6)
+
+ x = jax.random.normal(data_rng, shape=(BATCH, FEATURES))
+
+ for mode in ('train', 'eval'):
+ train = mode == 'train'
+ y1 = orig_model.apply(
+ initial_params_original,
+ x,
+ train=train,
+ rngs={'dropout': dropout_rng})
+ y2 = cust_model.apply(
+ initial_params_custom,
+ x,
+ train=train,
+ dropout_rate=dropout_rate,
+ rngs={'dropout': dropout_rng})
+
+ assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
+
+ @parameterized.named_parameters(
+ dict(testcase_name='DLRMResNet, default', model='dlrm_resnet'),
+ dict(testcase_name='DlrmSmall, default', model='dlrm_small'),
+ )
+ def test_default_dropout(self, model):
+ """Test default dropout_rate."""
+ OrigCls, CustCls = (
+ (OriginalDLRMResNet, CustomDLRMResNet)
+ if model == 'dlrm_resnet'
+ else (OriginalDlrmSmall, CustomDlrmSmall)
+ )
+ rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
+ fake_batch = jnp.ones((2, 39))
+ orig_model = OrigCls(vocab_size=VOCAB)
+ cust_model = CustCls(vocab_size=VOCAB)
+ initial_params_original = orig_model.init({'params': rng},
+ fake_batch,
+ train=False)
+ initial_params_custom = cust_model.init({'params': rng},
+ fake_batch,
+ train=False)
+
+ x = jax.random.normal(data_rng, shape=(BATCH, FEATURES))
+
+ for mode in ('train', 'eval'):
+ train = mode == 'train'
+ y1 = orig_model.apply(
+ initial_params_original,
+ x,
+ train=train,
+ rngs={'dropout': dropout_rng})
+ y2 = cust_model.apply(
+ initial_params_custom, x, train=train, rngs={'dropout': dropout_rng})
+
+ assert jnp.allclose(y1, y2, atol=0, rtol=0)
+
+
+if __name__ == '__main__':
+ absltest.main()
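All of the equivalence tests in this series follow the pattern used above: a reference model whose dropout rate is fixed at construction is compared against a custom model that takes the rate as an apply-time argument. A minimal Flax sketch of the two calling conventions (hypothetical module names, not the workload models themselves):

import jax
import jax.numpy as jnp
import flax.linen as nn


class RefModel(nn.Module):  # rate fixed at construction
  dropout_rate: float = 0.1

  @nn.compact
  def __call__(self, x, train):
    return nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)


class CallRateModel(nn.Module):  # rate supplied per call

  @nn.compact
  def __call__(self, x, train, dropout_rate=0.1):
    return nn.Dropout(rate=dropout_rate, deterministic=not train)(x)


x = jnp.ones((4, 8))
rng = jax.random.key(0)
# Dropout has no parameters, so the (empty) variables can be shared.
variables = RefModel().init({'params': rng}, x, train=False)
y_ref = RefModel().apply(variables, x, train=True, rngs={'dropout': rng})
y_cust = CallRateModel().apply(
    variables, x, train=True, dropout_rate=0.1, rngs={'dropout': rng})
assert jnp.allclose(y_ref, y_cust)  # same rng and same rate, same mask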
diff --git a/tests/dropout_fix/fastmri_jax/test_model_equivalence.py b/tests/dropout_fix/fastmri_jax/test_model_equivalence.py
new file mode 100644
index 000000000..1d318e8c6
--- /dev/null
+++ b/tests/dropout_fix/fastmri_jax/test_model_equivalence.py
@@ -0,0 +1,130 @@
+"""
+Runs fwd pass with random input for FASTMRI U-Net models and compares outputs.
+Run it as:
+  python3 tests/dropout_fix/fastmri_jax/test_model_equivalence.py
+"""
+
+import os
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import torch
+from torch.testing import assert_close
+
+from algoperf.workloads.fastmri.fastmri_pytorch.models import \
+ UNet as OriginalUNet
+from algoperf.workloads.fastmri.fastmri_pytorch.models_dropout import \
+ UNet as CustomUNet
+
+BATCH, IN_CHANS, H, W = 4, 1, 256, 256
+OUT_CHANS, C, LAYERS = 1, 32, 4
+DEVICE = 'cuda'
+TORCH_COMPILE = False
+SEED = 1996
+
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
+torch.use_deterministic_algorithms(True)
+
+
+class FastMRIModeEquivalenceTest(parameterized.TestCase):
+
+ def fwd_pass(self, orig, cust, dropout_rate):
+ x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+ torch.manual_seed(0)
+ y1 = orig(x)
+ torch.manual_seed(0)
+ y2 = cust(x, dropout_rate)
+ assert_close(y1, y2, atol=0, rtol=0)
+ if mode == 'eval': # one extra test: omit dropout at eval
+ torch.manual_seed(0)
+ y2 = cust(x)
+ assert_close(y1, y2, atol=0, rtol=0)
+
+ @parameterized.named_parameters(
+ dict(testcase_name='p=0.0', dropout_rate=0.0),
+ dict(testcase_name='p=0.1', dropout_rate=0.1),
+ dict(testcase_name='p=0.7', dropout_rate=0.7),
+ dict(testcase_name='p=1.0', dropout_rate=1.0),
+ )
+ def test_dropout_values(self, dropout_rate):
+ """Test different values of dropout_rate."""
+
+ torch.manual_seed(SEED)
+ orig = OriginalUNet(
+ IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
+
+ cust.load_state_dict(orig.state_dict()) # sync weights
+ if TORCH_COMPILE:
+ orig = torch.compile(orig)
+ cust = torch.compile(cust)
+
+ self.fwd_pass(orig, cust, dropout_rate)
+
+ @parameterized.named_parameters(
+ dict(testcase_name='default', use_tanh=False, use_layer_norm=False),
+ dict(testcase_name='tanh', use_tanh=True, use_layer_norm=False),
+ dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True),
+ dict(testcase_name='both', use_tanh=True, use_layer_norm=True),
+ )
+ def test_arch_configs(self, use_tanh, use_layer_norm):
+ """Test different architecture configurations, fixed dropout_rate."""
+ dropout_rate = 0.1
+
+ torch.manual_seed(SEED)
+ orig = OriginalUNet(
+ IN_CHANS,
+ OUT_CHANS,
+ C,
+ LAYERS,
+ dropout_rate=dropout_rate,
+ use_tanh=use_tanh,
+ use_layer_norm=use_layer_norm).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustomUNet(
+ IN_CHANS,
+ OUT_CHANS,
+ C,
+ LAYERS,
+ use_tanh=use_tanh,
+ use_layer_norm=use_layer_norm).to(DEVICE)
+
+ cust.load_state_dict(orig.state_dict()) # sync weights
+ if TORCH_COMPILE:
+ orig = torch.compile(orig)
+ cust = torch.compile(cust)
+
+ self.fwd_pass(orig, cust, dropout_rate)
+
+ @parameterized.named_parameters(
+ dict(testcase_name=''),)
+ def test_default_dropout(self):
+ """Test default dropout_rate."""
+
+ torch.manual_seed(SEED)
+ orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
+ torch.manual_seed(SEED)
+ cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
+ cust.load_state_dict(orig.state_dict()) # sync weights
+
+ x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+ torch.manual_seed(0)
+ y1 = orig(x)
+ torch.manual_seed(0)
+ y2 = cust(x)
+ assert_close(y1, y2, atol=0, rtol=0)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py
new file mode 100644
index 000000000..f51eaec7e
--- /dev/null
+++ b/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py
@@ -0,0 +1,138 @@
+"""
+Runs fwd pass with random input for ImageNet ViT models and compares outputs.
+Run it as:
+ python3 tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
+"""
+
+import itertools
+import os
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import torch
+from torch.testing import assert_close
+
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import \
+ ViT as OriginalVit
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.models_dropout import \
+ ViT as CustomVit
+
+# Model / test hyper-params
+BATCH, C, H, W = 4, 3, 224, 224 # input shape (N,C,H,W)
+WIDTH, DEPTH, HEADS = 256, 4, 8
+DROPOUT_RATE = None
+DEVICE = 'cuda'
+SEED = 1996
+
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
+torch.use_deterministic_algorithms(True)
+
+
+class ImageNetVitModeEquivalenceTest(parameterized.TestCase):
+
+ def fwd_pass(self, orig, cust, dropout_rate):
+ x = torch.randn(BATCH, C, H, W, device=DEVICE)
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+ torch.manual_seed(0)
+ y1 = orig(x)
+ torch.manual_seed(0)
+ y2 = cust(x, dropout_rate)
+ assert_close(y1, y2, atol=0, rtol=0)
+ if mode == 'eval': # one extra test: omit dropout at eval
+ torch.manual_seed(0)
+ y2 = cust(x, dropout_rate)
+ assert_close(y1, y2, atol=0, rtol=0)
+
+ @parameterized.named_parameters(
+ dict(testcase_name='p=0.0', dropout_rate=0.0),
+ dict(testcase_name='p=0.1', dropout_rate=0.1),
+ dict(testcase_name='p=0.6', dropout_rate=0.6),
+ dict(testcase_name='p=1.0', dropout_rate=1.0),
+ )
+ def test_dropout_values(self, dropout_rate):
+ """Test different dropout_rates."""
+
+ torch.manual_seed(SEED)
+ orig = OriginalVit(
+ width=WIDTH,
+ depth=DEPTH,
+ num_heads=HEADS,
+ dropout_rate=dropout_rate,
+ ).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustomVit(
+ width=WIDTH,
+ depth=DEPTH,
+ num_heads=HEADS,
+ ).to(DEVICE)
+
+ cust.load_state_dict(orig.state_dict()) # sync weights
+ self.fwd_pass(orig, cust, dropout_rate)
+
+ @parameterized.named_parameters([
+ dict(
+ testcase_name=f"GLU={use_glu}_LN={use_post_ln}_MAP={use_map}",
+ use_glu=use_glu,
+ use_post_ln=use_post_ln,
+ use_map=use_map,
+ ) for use_glu,
+ use_post_ln,
+ use_map in itertools.product([False, True], repeat=3)
+ ])
+ def test_arch(self, use_glu, use_post_ln, use_map):
+ """Test different architecture configurations, fixed dropout_rate."""
+ dropout_rate = 0.1
+
+ torch.manual_seed(SEED)
+ orig = OriginalVit(
+ width=WIDTH,
+ depth=DEPTH,
+ num_heads=HEADS,
+ use_glu=use_glu,
+ use_post_layer_norm=use_post_ln,
+ use_map=use_map,
+ dropout_rate=dropout_rate,
+ ).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustomVit(
+ width=WIDTH,
+ depth=DEPTH,
+ num_heads=HEADS,
+ use_glu=use_glu,
+ use_post_layer_norm=use_post_ln,
+ use_map=use_map,
+ ).to(DEVICE)
+
+ cust.load_state_dict(orig.state_dict()) # sync weights
+ self.fwd_pass(orig, cust, dropout_rate)
+
+ @parameterized.named_parameters(
+ dict(testcase_name=''),)
+ def test_default_dropout(self):
+ """Test default dropout_rate."""
+
+ torch.manual_seed(SEED)
+ orig = OriginalVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE)
+ torch.manual_seed(SEED)
+ cust = CustomVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE)
+ cust.load_state_dict(orig.state_dict()) # sync weights
+
+ x = torch.randn(BATCH, C, H, W, device=DEVICE)
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+ torch.manual_seed(0)
+ y1 = orig(x)
+ torch.manual_seed(0)
+ y2 = cust(x)
+ assert_close(y1, y2, atol=0, rtol=0)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py
new file mode 100644
index 000000000..02f3a3d84
--- /dev/null
+++ b/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py
@@ -0,0 +1,84 @@
+"""
+Runs fwd pass with random input for LIBRISPEECH Conformer models and compares outputs.
+Run with:
+ python3 tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
+
+NOTE: we don't test for default dropout_rate values, since they changed.
+"""
+
+import os
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import torch
+from torch.testing import assert_close
+
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
+ ConformerConfig as OriginalConfig
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
+ ConformerEncoderDecoder as OriginalModel
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import \
+ ConformerConfig as CustomConfig
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import \
+ ConformerEncoderDecoder as CustomModel
+
+N_LAYERS = 3
+B, T = 32, 36_000
+DEVICE = 'cuda'
+
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
+torch.use_deterministic_algorithms(mode=True)
+SEED = 1996
+
+
+class ConformerEquivalenceTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ dict(testcase_name='p=0.0', dropout_rate=0.0),
+ dict(testcase_name='p=0.2', dropout_rate=0.2),
+ dict(testcase_name='p=0.7', dropout_rate=0.7),
+ dict(testcase_name='p=1.0', dropout_rate=1.0),
+ )
+ def test_forward(self, dropout_rate):
+
+ torch.manual_seed(SEED)
+ orig = OriginalModel(
+ OriginalConfig(
+ num_encoder_layers=N_LAYERS,
+ attention_residual_dropout_rate=dropout_rate,
+ conv_residual_dropout_rate=dropout_rate,
+ feed_forward_residual_dropout_rate=dropout_rate,
+ input_dropout_rate=dropout_rate,
+ )).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustomModel(CustomConfig(num_encoder_layers=N_LAYERS)).to(DEVICE)
+
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ x = torch.randn(B, T, device=DEVICE)
+ paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
+
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
+ torch.manual_seed(SEED)
+ y1, p1 = orig(x, paddings)
+ torch.manual_seed(SEED)
+ y2, p2 = cust(x, paddings, dropout_rate=dropout_rate)
+
+ assert_close(y1, y2, atol=0, rtol=0)
+ assert_close(p1, p2, atol=0, rtol=0)
+
+ if mode == 'eval': # one extra test: omit dropout at eval
+ torch.manual_seed(SEED)
+ y2, p2 = cust(x, paddings)
+ assert_close(y1, y2, atol=0, rtol=0)
+ assert_close(p1, p2, atol=0, rtol=0)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py
new file mode 100644
index 000000000..7d6a94592
--- /dev/null
+++ b/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py
@@ -0,0 +1,124 @@
+"""
+Runs fwd pass with random input for LIBRISPEECH Deepspeech models and compares outputs.
+Run with:
+ python3 tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
+
+`dropout_rate` controls the following args:
+- `input_dropout_rate` (if None, 0.1
+- `feed_forward_dropout_rate` (if None, 0.1)
+"""
+
+import os
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import torch
+from torch.testing import assert_close
+
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \
+ DeepspeechConfig as OriginalConfig
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \
+ DeepspeechEncoderDecoder as OriginalModel
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import \
+ DeepspeechConfig as CustomConfig
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import \
+ DeepspeechEncoderDecoder as CustomModel
+
+B, T = 32, 30_000
+DEVICE = 'cuda'
+TORCH_COMPILE = False
+
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
+torch.use_deterministic_algorithms(mode=True)
+SEED = 1996
+
+
+class DeepSpeechEquivalenceTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ dict(testcase_name='p=0.0', dropout_rate=0.0),
+ dict(testcase_name='p=0.2', dropout_rate=0.2),
+ dict(testcase_name='p=0.7', dropout_rate=0.7),
+ dict(testcase_name='p=1.0', dropout_rate=1.0),
+ )
+ def test_forward(self, dropout_rate):
+ """Test different dropout_rate values."""
+
+ torch.manual_seed(SEED)
+ orig = OriginalModel(
+ OriginalConfig(
+ num_lstm_layers=2,
+ num_ffn_layers=2,
+ input_dropout_rate=dropout_rate,
+ feed_forward_dropout_rate=dropout_rate,
+ )).to(DEVICE)
+
+ torch.manual_seed(SEED)
+ cust = CustomModel(CustomConfig(
+ num_lstm_layers=2,
+ num_ffn_layers=2,
+ )).to(DEVICE)
+
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ if TORCH_COMPILE:
+ orig = torch.compile(orig)
+ cust = torch.compile(cust)
+
+ x = torch.randn(B, T, device=DEVICE)
+ paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
+
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
+ torch.manual_seed(SEED)
+ y1, p1 = orig(x, paddings)
+
+ torch.manual_seed(SEED)
+ y2, p2 = cust(x, paddings, dropout_rate=dropout_rate)
+
+ assert_close(y1, y2, atol=0, rtol=0)
+ assert_close(p1, p2, atol=0, rtol=0)
+
+ if mode == 'eval': # one extra test: omit dropout at eval
+ torch.manual_seed(SEED)
+ y2, p2 = cust(x, paddings)
+ assert_close(y1, y2, atol=0, rtol=0)
+ assert_close(p1, p2, atol=0, rtol=0)
+
+ @parameterized.named_parameters(
+ dict(testcase_name=''),)
+ def test_default_dropout(self):
+ """Test default dropout_rate."""
+
+ torch.manual_seed(SEED)
+ orig = OriginalModel(OriginalConfig(num_lstm_layers=2,
+ num_ffn_layers=2)).to(DEVICE)
+ torch.manual_seed(SEED)
+ cust = CustomModel(CustomConfig(num_lstm_layers=2,
+ num_ffn_layers=2)).to(DEVICE)
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ if TORCH_COMPILE:
+ orig = torch.compile(orig)
+ cust = torch.compile(cust)
+
+ x = torch.randn(B, T, device=DEVICE)
+ paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
+
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+ torch.manual_seed(SEED)
+ y1, p1 = orig(x, paddings)
+ torch.manual_seed(SEED)
+ y2, p2 = cust(x, paddings)
+ assert_close(y1, y2, atol=0, rtol=0)
+ assert_close(p1, p2, atol=0, rtol=0)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/tests/dropout_fix/ogbg_pytorch_jax/test_model_equivalence.py b/tests/dropout_fix/ogbg_pytorch_jax/test_model_equivalence.py
new file mode 100644
index 000000000..3b3feb680
--- /dev/null
+++ b/tests/dropout_fix/ogbg_pytorch_jax/test_model_equivalence.py
@@ -0,0 +1,118 @@
+"""
+Runs fwd pass with random graphs for OGBG GNN models and compares outputs.
+Run with:
+ python3 tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
+"""
+
+import os
+import random
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from jraph import GraphsTuple
+import numpy as np
+import torch
+from torch.testing import assert_close
+
+from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as OriginalModel
+from algoperf.workloads.ogbg.ogbg_pytorch.models_dropout import \
+ GNN as CustomModel
+
+B, N, E = 8, 20, 40 # graphs, nodes/graph, edges/graph
+NODE_FDIM, EDGE_FDIM = 9, 3 # expected feature dims
+DEVICE = 'cuda'
+
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
+torch.use_deterministic_algorithms(True)
+SEED = 1996
+
+
+def _rand_graph():
+ total_nodes, total_edges = B * N, B * E
+ nodes = torch.randn(total_nodes, NODE_FDIM, device=DEVICE)
+ edges = torch.randn(total_edges, EDGE_FDIM, device=DEVICE)
+ senders, receivers = [], []
+ for i in range(B):
+ offset = i * N
+ s = torch.randint(N, (E,), device=DEVICE) + offset
+ r = torch.randint(N, (E,), device=DEVICE) + offset
+ senders.append(s), receivers.append(r)
+ senders = torch.cat(senders)
+ receivers = torch.cat(receivers)
+ n_node = torch.full((B,), N, device=DEVICE, dtype=torch.int32)
+ n_edge = torch.full((B,), E, device=DEVICE, dtype=torch.int32)
+ return GraphsTuple(nodes, edges, receivers, senders, None, n_node, n_edge)
+
+
+class GNNEquivalenceTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ dict(testcase_name='0.0', dropout_rate=0.0),
+ dict(testcase_name='0.2', dropout_rate=0.2),
+ dict(testcase_name='0.7', dropout_rate=0.7),
+ dict(testcase_name='1.0', dropout_rate=1.0),
+ )
+ def test_forward(self, dropout_rate):
+ """Test different dropout_rates."""
+
+ orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE)
+ cust = CustomModel().to(DEVICE)
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ graph = _rand_graph()
+
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y1 = orig(graph)
+
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y2 = cust(graph, dropout_rate=dropout_rate)
+
+ assert_close(y1, y2, atol=0, rtol=0)
+
+ if mode == 'eval': # one extra test: omit dropout at eval
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y2 = cust(graph)
+ assert_close(y1, y2, atol=0, rtol=0)
+
+ @parameterized.named_parameters(
+ dict(testcase_name=''),)
+ def test_default_dropout(self):
+ """Test default dropout_rate."""
+
+ orig = OriginalModel().to(DEVICE)
+ cust = CustomModel().to(DEVICE)
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ graph = _rand_graph()
+
+ for mode in ('train', 'eval'):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y1 = orig(graph)
+
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y2 = cust(graph)
+
+ assert_close(y1, y2, atol=0, rtol=0)
+
+
+if __name__ == '__main__':
+ absltest.main()
diff --git a/tests/dropout_fix/wmt_pytorch_jax/test_model_equivalence.py b/tests/dropout_fix/wmt_pytorch_jax/test_model_equivalence.py
new file mode 100644
index 000000000..03f289a68
--- /dev/null
+++ b/tests/dropout_fix/wmt_pytorch_jax/test_model_equivalence.py
@@ -0,0 +1,121 @@
+"""
+Runs fwd pass with random input for WMT Transformer models and compares outputs.
+Run with:
+ python3 tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
+"""
+
+import os
+import random
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import numpy as np
+import torch
+from torch.testing import assert_close
+
+from algoperf.workloads.wmt.wmt_pytorch.models import \
+ Transformer as OriginalModel
+from algoperf.workloads.wmt.wmt_pytorch.models_dropout import \
+ Transformer as CustomModel
+
+B, SRC_LEN, TGT_LEN, NTOK = 16, 80, 80, 32_000
+DEVICE = "cuda"
+SEED = 1996
+
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
+torch.use_deterministic_algorithms(True)
+
+
+def _rand_tokens(bs, seqlen):
+ return torch.randint(1, NTOK, (bs, seqlen), device=DEVICE)
+
+
+class TransformerEquivalenceTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ # NOTE: removed dropout=1.0 since it will generate nan in scaled_dot_product_attention
+ dict(testcase_name="0.0", dropout_rate=0.0, compile=False),
+ dict(testcase_name="0.2", dropout_rate=0.2, compile=False),
+ dict(testcase_name="0.7", dropout_rate=0.7, compile=False),
+ dict(testcase_name="p=0.0_compile", dropout_rate=0.0, compile=True),
+ dict(testcase_name="p=0.2_compile", dropout_rate=0.2, compile=True),
+ dict(testcase_name="p=0.7_compile", dropout_rate=0.7, compile=True),
+ )
+ def test_dropout_value(self, dropout_rate, compile):
+
+ orig = OriginalModel(
+ dropout_rate=dropout_rate,
+ attention_dropout_rate=dropout_rate).to(DEVICE)
+ cust = CustomModel().to(DEVICE)
+
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ if compile:
+ orig = torch.compile(orig)
+ cust = torch.compile(cust)
+
+ src = _rand_tokens(B, SRC_LEN)
+ tgt = _rand_tokens(B, TGT_LEN)
+
+ for mode in ("train", "eval"):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y1 = orig(src, tgt)
+
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y2 = cust(src, tgt, dropout_rate=dropout_rate)
+
+ assert_close(y1, y2, atol=0, rtol=0)
+
+ if mode == 'eval': # one extra test: omit dropout at eval
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y2 = cust(src, tgt)
+ assert_close(y1, y2, atol=0, rtol=0)
+
+ @parameterized.named_parameters(
+ dict(testcase_name="default", compile=False),
+ dict(testcase_name="default_compile", compile=True),
+ )
+ def test_default(self, compile):
+
+ orig = OriginalModel().to(DEVICE)
+ cust = CustomModel().to(DEVICE)
+
+ orig.load_state_dict(cust.state_dict()) # sync weights
+
+ if compile:
+ orig = torch.compile(orig)
+ cust = torch.compile(cust)
+
+ src = _rand_tokens(B, SRC_LEN)
+ tgt = _rand_tokens(B, TGT_LEN)
+
+ for mode in ("train", "eval"):
+ getattr(orig, mode)()
+ getattr(cust, mode)()
+
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y1 = orig(src, tgt)
+
+ torch.manual_seed(SEED)
+ random.seed(SEED)
+ np.random.seed(SEED)
+ y2 = cust(src, tgt)
+
+ assert_close(y1, y2, atol=0, rtol=0)
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/tests/test_jax_utils.py b/tests/test_jax_utils.py
new file mode 100644
index 000000000..04713915f
--- /dev/null
+++ b/tests/test_jax_utils.py
@@ -0,0 +1,237 @@
+"""
+Test algoperf.jax_utils.Dropout by comparing to flax.linen.Dropout
+Run it as: pytest
+"""
+
+import os
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import jax
+import jax.numpy as jnp
+import flax.linen as nn
+
+from jax.tree_util import tree_structure, tree_leaves, tree_map
+from algoperf.jax_utils import Dropout
+from functools import partial
+
+
+SEED = 1996
+DEFAULT_DROPOUT = 0.5
+
+
+def pytrees_are_equal(a, b, rtol=1e-5, atol=1e-8):
+ """
+ A custom function to check if two PyTrees are equal, handling floats with a tolerance.
+ """
+ # 1. Check if the structures are the same
+ if tree_structure(a) != tree_structure(b):
+ return False
+
+ # 2. Define a comparison function for leaves
+ def leaf_comparator(x, y):
+ # Use allclose for floating-point JAX arrays
+ if isinstance(x, jnp.ndarray) and jnp.issubdtype(x.dtype, jnp.floating):
+ return jnp.allclose(x, y, rtol=rtol, atol=atol)
+ # Use standard equality for everything else
+ else:
+ return x == y
+
+ # 3. Map the comparison function over the trees and check if all results are True
+ # We also need to flatten the results of the tree_map and check if all are True
+ comparison_tree = tree_map(leaf_comparator, a, b)
+ all_equal = all(tree_leaves(comparison_tree))
+
+ return all_equal
+
+
+class LegacyDropoutModel(nn.Module):
+ dropout_rate: float = DEFAULT_DROPOUT
+
+ @nn.compact
+ def __call__(self, x, train):
+ return nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
+
+
+class DropoutModel(nn.Module):
+
+ @nn.compact
+ def __call__(self, x, train, dropout_rate=DEFAULT_DROPOUT):
+ return Dropout(
+ rate=dropout_rate, deterministic=not train)(
+ x, rate=dropout_rate)
+
+
+class ModelEquivalenceTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ dict(
+ testcase_name="Dropout, p=0.0, train", dropout_rate=0.0,
+ mode="train"),
+ dict(testcase_name="Dropout, p=0.0, eval", dropout_rate=0.0, mode="eval"),
+ dict(
+ testcase_name="Dropout, p=0.1, train", dropout_rate=0.1,
+ mode="train"),
+ dict(testcase_name="Dropout, p=0.1, eval", dropout_rate=0.1, mode="eval"),
+ )
+ def test_forward(self, dropout_rate, mode):
+ """ Compare forward pass of Dropout layer to flax.linen.Dropout in train and
+ eval mode.
+ """
+
+ # initialize models
+ rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
+ fake_batch = jnp.ones((10,))
+ orig_model = LegacyDropoutModel(dropout_rate=dropout_rate)
+ cust_model = DropoutModel()
+
+ initial_variables_original = orig_model.init({"params": rng},
+ fake_batch,
+ train=False)
+ initial_variables_custom = cust_model.init({"params": rng},
+ fake_batch,
+ train=False)
+
+ assert pytrees_are_equal(
+ initial_variables_original, initial_variables_custom, rtol=1e-6)
+
+ # forward pass
+ x = jnp.ones((10,))
+
+ train = mode == "train"
+ y1 = orig_model.apply(
+ initial_variables_original,
+ x,
+ train=train,
+ rngs={"dropout": dropout_rng})
+ y2 = cust_model.apply(
+ initial_variables_custom,
+ x,
+ train=train,
+ dropout_rate=dropout_rate,
+ rngs={"dropout": dropout_rng},
+ )
+
+ assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
+
+ @parameterized.named_parameters(
+ dict(
+ testcase_name="Dropout, p=0.0, train", dropout_rate=0.0,
+ mode="train"),
+ dict(testcase_name="Dropout, p=0.0, eval", dropout_rate=0.0, mode="eval"),
+ dict(
+ testcase_name="Dropout, p=0.1, train", dropout_rate=0.1,
+ mode="train"),
+ dict(testcase_name="Dropout, p=0.1, eval", dropout_rate=0.1, mode="eval"),
+ )
+ def test_dropout_update(self, dropout_rate, mode):
+ """Call forward pass of Dropout layer with two different dropout rates
+ and check that the output matches to flax.linen.Dropout in train and
+ eval mode.
+ """
+ # init model
+ rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
+ fake_batch = jnp.ones((10,))
+ orig_model = LegacyDropoutModel(dropout_rate=dropout_rate)
+ cust_model = DropoutModel()
+
+ initial_variables_original = orig_model.init({"params": rng},
+ fake_batch,
+ train=False)
+
+ initial_variables_custom = cust_model.init({"params": rng},
+ fake_batch,
+ train=False)
+
+ assert pytrees_are_equal(
+ initial_variables_original, initial_variables_custom, rtol=1e-6)
+
+ # forward pass
+ x = jnp.ones((10,))
+
+ train = mode == "train"
+ y1 = orig_model.apply(
+ initial_variables_original,
+ x,
+ train=train,
+ rngs={"dropout": dropout_rng})
+
+ _ = cust_model.apply(
+ initial_variables_custom,
+ x,
+ train=train,
+ dropout_rate=0.9,
+ rngs={"dropout": dropout_rng},
+ )
+
+ y2 = cust_model.apply(
+ initial_variables_custom,
+ x,
+ train=train,
+ dropout_rate=dropout_rate,
+ rngs={"dropout": dropout_rng},
+ )
+ assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
+
+ @parameterized.named_parameters(
+ dict(
+ testcase_name="Dropout, p=0.0, train", dropout_rate=0.0,
+ mode="train"),
+ dict(testcase_name="Dropout, p=0.0, eval", dropout_rate=0.0, mode="eval"),
+ dict(
+ testcase_name="Dropout, p=0.1, train", dropout_rate=0.1,
+ mode="train"),
+ dict(testcase_name="Dropout, p=0.1, eval", dropout_rate=0.1, mode="eval"),
+ )
+ def test_jitted_updates(self, dropout_rate, mode):
+ """ Compare forward pass of Dropout layer to flax.linen.Dropout in train and
+ eval mode.
+ """
+
+ # initialize models
+ rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
+ fake_batch = jnp.ones((10,))
+ orig_model = LegacyDropoutModel(dropout_rate=dropout_rate)
+ cust_model = DropoutModel()
+
+ initial_variables_original = orig_model.init({"params": rng},
+ fake_batch,
+ train=False)
+ initial_variables_custom = cust_model.init({"params": rng},
+ fake_batch,
+ train=False)
+
+ assert pytrees_are_equal(
+ initial_variables_original, initial_variables_custom, rtol=1e-6)
+
+ # forward pass
+ x = jnp.ones((10,))
+
+ train = mode == "train"
+ jitted_original_apply = jax.jit(
+ partial(orig_model.apply), static_argnames=['train'])
+ jitted_custom_apply = jax.jit(
+ partial(cust_model.apply), static_argnames=['train'])
+
+ def multiple_fwd_passes_custom_layer():
+ for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
+ y2 = jitted_custom_apply(
+ initial_variables_custom,
+ x,
+ train=train,
+ dropout_rate=d,
+ rngs={"dropout": dropout_rng},
+ )
+ return y2
+
+ def multiple_fwd_passes_original_layer():
+ for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
+ y1 = jitted_original_apply(
+ initial_variables_original,
+ x,
+ train=train,
+ rngs={"dropout": dropout_rng})
+      return y1
+
+    # Run both sweeps; defining the helpers above does not execute them. At
+    # the final sweep value the custom layer uses the same rate as the
+    # reference layer, so the outputs must match.
+    y1 = multiple_fwd_passes_original_layer()
+    y2 = multiple_fwd_passes_custom_layer()
+    assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
+
+if __name__ == "__main__":
+ absltest.main()
From d3f25d8ec082f37541a92f340e41cbb0b9a6b269 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Sat, 21 Jun 2025 08:05:03 +0000
Subject: [PATCH 082/123] add tests
---
.../criteo1tb/criteo1tb_jax/models_ref.py | 2 +-
.../librispeech_jax/models_ref.py | 6 +-
algoperf/workloads/ogbg/ogbg_jax/models.py | 2 +-
.../workloads/ogbg/ogbg_jax/models_ref.py | 16 +-
.../fastmri_jax/test_model_equivalence.py | 174 ++++++++--------
.../test_model_equivalence.py | 188 ++++++++----------
.../test_model_equivalence.py | 147 ++++++++------
.../test_model_equivalence.py | 161 +++++++--------
.../test_model_equivalence.py | 118 -----------
.../wmt_pytorch_jax/test_model_equivalence.py | 121 -----------
10 files changed, 337 insertions(+), 598 deletions(-)
delete mode 100644 tests/dropout_fix/ogbg_pytorch_jax/test_model_equivalence.py
delete mode 100644 tests/dropout_fix/wmt_pytorch_jax/test_model_equivalence.py
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py
index 8406b9eb1..d0352e290 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py
@@ -23,7 +23,7 @@ class DLRMResNet(nn.Module):
mlp_bottom_dims: Sequence[int] = (256, 256, 256)
mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1)
embed_dim: int = 128
- dropout_rate: float = 0.0
+ dropout_rate: float = 0.1
use_layer_norm: bool = False # Unused.
embedding_init_multiplier: float = None # Unused
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py
index 969d9423c..f168833d3 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py
@@ -36,12 +36,12 @@ class ConformerConfig:
encoder_dim: int = 512
num_attention_heads: int = 8
num_encoder_layers: int = 4
- attention_dropout_rate: float = 0.0
+ attention_dropout_rate: float = 0.1
# If None, defaults to 0.1.
attention_residual_dropout_rate: Optional[float] = 0.1
# If None, defaults to 0.0.
- conv_residual_dropout_rate: Optional[float] = 0.0
- feed_forward_dropout_rate: float = 0.0
+ conv_residual_dropout_rate: Optional[float] = 0.1
+ feed_forward_dropout_rate: float = 0.1
# If None, defaults to 0.1.
feed_forward_residual_dropout_rate: Optional[float] = 0.1
convolution_kernel_size: int = 5
diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py
index 8524bb60e..7207e033d 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/models.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/models.py
@@ -93,4 +93,4 @@ def __call__(self, graph, train, dropout_rate=DROPOUT_RATE):
decoder = jraph.GraphMapFeatures(embed_global_fn=nn.Dense(self.num_outputs))
graph = decoder(graph)
- return graph.globals
+ return graph.globals
\ No newline at end of file
diff --git a/algoperf/workloads/ogbg/ogbg_jax/models_ref.py b/algoperf/workloads/ogbg/ogbg_jax/models_ref.py
index f0a9e3dc1..ca3d89426 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/models_ref.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/models_ref.py
@@ -15,7 +15,7 @@ def make_fn(inputs):
return make_fn
-def _make_mlp(hidden_dims, dropout, activation_fn):
+def _make_mlp(hidden_dims, activation_fn, train, dropout_rate):
"""Creates a MLP with specified dimensions."""
@jraph.concatenated_args
@@ -25,7 +25,7 @@ def make_fn(inputs):
x = nn.Dense(features=dim)(x)
x = nn.LayerNorm()(x)
x = activation_fn(x)
- x = dropout(x)
+ x = nn.Dropout(rate=dropout_rate, deterministic=not train)(x)
return x
return make_fn
@@ -46,11 +46,7 @@ class GNN(nn.Module):
@nn.compact
def __call__(self, graph, train):
- if self.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = self.dropout_rate
- dropout = nn.Dropout(rate=dropout_rate, deterministic=not train)
+ dropout_rate = self.dropout_rate
graph = graph._replace(
globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]))
@@ -73,11 +69,11 @@ def __call__(self, graph, train):
for _ in range(self.num_message_passing_steps):
net = jraph.GraphNetwork(
update_edge_fn=_make_mlp(
- self.hidden_dims, dropout=dropout, activation_fn=activation_fn),
+ self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate),
update_node_fn=_make_mlp(
- self.hidden_dims, dropout=dropout, activation_fn=activation_fn),
+ self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate),
update_global_fn=_make_mlp(
- self.hidden_dims, dropout=dropout, activation_fn=activation_fn))
+ self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate))
graph = net(graph)
diff --git a/tests/dropout_fix/fastmri_jax/test_model_equivalence.py b/tests/dropout_fix/fastmri_jax/test_model_equivalence.py
index 1d318e8c6..fbb6d2499 100644
--- a/tests/dropout_fix/fastmri_jax/test_model_equivalence.py
+++ b/tests/dropout_fix/fastmri_jax/test_model_equivalence.py
@@ -6,125 +6,113 @@
import os
+
from absl.testing import absltest
from absl.testing import parameterized
-import torch
-from torch.testing import assert_close
+import jax
+import jax.numpy as jnp
+
-from algoperf.workloads.fastmri.fastmri_pytorch.models import \
+from algoperf.workloads.fastmri.fastmri_jax.models_ref import \
UNet as OriginalUNet
-from algoperf.workloads.fastmri.fastmri_pytorch.models_dropout import \
+from algoperf.workloads.fastmri.fastmri_jax.models import \
UNet as CustomUNet
BATCH, IN_CHANS, H, W = 4, 1, 256, 256
OUT_CHANS, C, LAYERS = 1, 32, 4
-DEVICE = 'cuda'
-TORCH_COMPILE = False
SEED = 1996
-os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = True
-torch.use_deterministic_algorithms(True)
-
-class FastMRIModeEquivalenceTest(parameterized.TestCase):
-
- def fwd_pass(self, orig, cust, dropout_rate):
- x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
- torch.manual_seed(0)
- y1 = orig(x)
- torch.manual_seed(0)
- y2 = cust(x, dropout_rate)
- assert_close(y1, y2, atol=0, rtol=0)
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(0)
- y2 = cust(x)
- assert_close(y1, y2, atol=0, rtol=0)
+class ModelEquivalenceTest(parameterized.TestCase):
@parameterized.named_parameters(
- dict(testcase_name='p=0.0', dropout_rate=0.0),
- dict(testcase_name='p=0.1', dropout_rate=0.1),
- dict(testcase_name='p=0.7', dropout_rate=0.7),
- dict(testcase_name='p=1.0', dropout_rate=1.0),
+ dict(
+ testcase_name='UNet, p=0.0',
+ dropout_rate=0.0),
+ dict(
+ testcase_name='UNet, p=0.1',
+ dropout_rate=0.1),
)
- def test_dropout_values(self, dropout_rate):
- """Test different values of dropout_rate."""
+ def test_forward(self, dropout_rate):
+ OrigCls, CustCls = (OriginalUNet, CustomUNet)
- torch.manual_seed(SEED)
- orig = OriginalUNet(
- IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate).to(DEVICE)
- torch.manual_seed(SEED)
- cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
+ # init model
+ rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
- cust.load_state_dict(orig.state_dict()) # sync weights
- if TORCH_COMPILE:
- orig = torch.compile(orig)
- cust = torch.compile(cust)
+    kwargs = dict(num_pool_layers=LAYERS, num_channels=IN_CHANS)
+ orig_model = OrigCls(**kwargs)
+ cust_model = CustCls(**kwargs)
- self.fwd_pass(orig, cust, dropout_rate)
+ fake_batch = jnp.ones((BATCH, IN_CHANS, H, W))
- @parameterized.named_parameters(
- dict(testcase_name='default', use_tanh=False, use_layer_norm=False),
- dict(testcase_name='tanh', use_tanh=True, use_layer_norm=False),
- dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True),
- dict(testcase_name='both', use_tanh=True, use_layer_norm=True),
- )
- def test_arch_configs(self, use_tanh, use_layer_norm):
- """Test different architecture configurations, fixed dropout_rate."""
- dropout_rate = 0.1
-
- torch.manual_seed(SEED)
- orig = OriginalUNet(
- IN_CHANS,
- OUT_CHANS,
- C,
- LAYERS,
- dropout_rate=dropout_rate,
- use_tanh=use_tanh,
- use_layer_norm=use_layer_norm).to(DEVICE)
-
- torch.manual_seed(SEED)
- cust = CustomUNet(
- IN_CHANS,
- OUT_CHANS,
- C,
- LAYERS,
- use_tanh=use_tanh,
- use_layer_norm=use_layer_norm).to(DEVICE)
-
- cust.load_state_dict(orig.state_dict()) # sync weights
- if TORCH_COMPILE:
- orig = torch.compile(orig)
- cust = torch.compile(cust)
-
- self.fwd_pass(orig, cust, dropout_rate)
+ initial_params_original = orig_model.init({'params': rng},
+ fake_batch,
+ train=False)
+ initial_params_custom = cust_model.init({'params': rng},
+ fake_batch,
+ train=False)
+
+ # fwd
+ x = jax.random.normal(data_rng, shape=(BATCH, H, W))
+
+ for mode in ('train', 'eval'):
+ train = mode == 'train'
+ y1 = orig_model.apply(
+ initial_params_original,
+ x,
+ train=train,
+ rngs={'dropout': dropout_rng})
+ y2 = cust_model.apply(
+ initial_params_custom,
+ x,
+ train=train,
+ dropout_rate=dropout_rate,
+ rngs={'dropout': dropout_rng})
+
+ assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
@parameterized.named_parameters(
- dict(testcase_name=''),)
+ dict(testcase_name='UNet, default'),
+ )
def test_default_dropout(self):
"""Test default dropout_rate."""
+ OrigCls, CustCls = (OriginalUNet, CustomUNet)
- torch.manual_seed(SEED)
- orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
- torch.manual_seed(SEED)
- cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
- cust.load_state_dict(orig.state_dict()) # sync weights
- x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
- torch.manual_seed(0)
- y1 = orig(x)
- torch.manual_seed(0)
- y2 = cust(x)
- assert_close(y1, y2, atol=0, rtol=0)
+ # init model
+ rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
+
+ kwargs = dict(num_pool_layers=LAYERS,
+ num_channels=IN_CHANS,
+ )
+ orig_model = OrigCls(**kwargs)
+ cust_model = CustCls(**kwargs)
+ fake_batch = jnp.ones((2, IN_CHANS, H, W))
+
+ initial_params_original = orig_model.init({'params': rng},
+ fake_batch,
+ train=False)
+ initial_params_custom = cust_model.init({'params': rng},
+ fake_batch,
+ train=False)
+
+ # fwd
+ x = jax.random.normal(data_rng, shape=(BATCH, H, W))
+
+ for mode in ('train', 'eval'):
+ train = mode == 'train'
+ y1 = orig_model.apply(
+ initial_params_original,
+ x,
+ train=train,
+ rngs={'dropout': dropout_rng})
+ y2 = cust_model.apply(
+ initial_params_custom, x, train=train, rngs={'dropout': dropout_rng})
+
+ assert jnp.allclose(y1, y2, atol=0, rtol=0)
if __name__ == '__main__':
absltest.main()
diff --git a/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py
index f51eaec7e..6806ca99f 100644
--- a/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py
+++ b/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py
@@ -1,138 +1,110 @@
"""
Runs fwd pass with random input for ImageNet ViT models and compares outputs.
Run it as:
- python3 tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
+ python3 tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py
"""
-import itertools
import os
from absl.testing import absltest
from absl.testing import parameterized
-import torch
-from torch.testing import assert_close
+import jax
+import jax.numpy as jnp
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import \
+from algoperf.workloads.imagenet_vit.imagenet_jax.models_ref import \
ViT as OriginalVit
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.models_dropout import \
+from algoperf.workloads.imagenet_vit.imagenet_jax.models import \
ViT as CustomVit
# Model / test hyper-params
-BATCH, C, H, W = 4, 3, 224, 224 # input shape (N,C,H,W)
-WIDTH, DEPTH, HEADS = 256, 4, 8
-DROPOUT_RATE = None
-DEVICE = 'cuda'
-SEED = 1996
+INPUT_SHAPE = (2, 224, 124, 3)
+SEED = 1994
-os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = True
-torch.use_deterministic_algorithms(True)
+class ImageNetVitModeEquivalenceTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ dict(
+ testcase_name='ViT, p=0.0',
+ dropout_rate=0.0),
+ dict(
+ testcase_name='ViT, p=0.1',
+ dropout_rate=0.1),
+ )
+ def test_forward(self, dropout_rate):
+ OrigCls, CustCls = (OriginalVit, CustomVit)
-class ImageNetVitModeEquivalenceTest(parameterized.TestCase):
+ # init model
+ rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
+
+ orig_model = OrigCls()
+ cust_model = CustCls()
+
+ fake_batch = jnp.ones(INPUT_SHAPE)
+
+ initial_params_original = orig_model.init({'params': rng},
+ fake_batch,
+ train=False)
+ initial_params_custom = cust_model.init({'params': rng},
+ fake_batch,
+ train=False)
+
+ # fwd
+ x = jax.random.normal(data_rng, shape=INPUT_SHAPE)
- def fwd_pass(self, orig, cust, dropout_rate):
- x = torch.randn(BATCH, C, H, W, device=DEVICE)
for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
- torch.manual_seed(0)
- y1 = orig(x)
- torch.manual_seed(0)
- y2 = cust(x, dropout_rate)
- assert_close(y1, y2, atol=0, rtol=0)
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(0)
- y2 = cust(x, dropout_rate)
- assert_close(y1, y2, atol=0, rtol=0)
+ train = mode == 'train'
+ y1 = orig_model.apply(
+ initial_params_original,
+ x,
+ train=train,
+ rngs={'dropout': dropout_rng})
+ y2 = cust_model.apply(
+ initial_params_custom,
+ x,
+ train=train,
+ dropout_rate=dropout_rate,
+ rngs={'dropout': dropout_rng})
+
+ assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
@parameterized.named_parameters(
- dict(testcase_name='p=0.0', dropout_rate=0.0),
- dict(testcase_name='p=0.1', dropout_rate=0.1),
- dict(testcase_name='p=0.6', dropout_rate=0.6),
- dict(testcase_name='p=1.0', dropout_rate=1.0),
+      dict(testcase_name='ViT, default'),
)
- def test_dropout_values(self, dropout_rate):
- """Test different dropout_rates."""
-
- torch.manual_seed(SEED)
- orig = OriginalVit(
- width=WIDTH,
- depth=DEPTH,
- num_heads=HEADS,
- dropout_rate=dropout_rate,
- ).to(DEVICE)
-
- torch.manual_seed(SEED)
- cust = CustomVit(
- width=WIDTH,
- depth=DEPTH,
- num_heads=HEADS,
- ).to(DEVICE)
-
- cust.load_state_dict(orig.state_dict()) # sync weights
- self.fwd_pass(orig, cust, dropout_rate)
-
- @parameterized.named_parameters([
- dict(
- testcase_name=f"GLU={use_glu}_LN={use_post_ln}_MAP={use_map}",
- use_glu=use_glu,
- use_post_ln=use_post_ln,
- use_map=use_map,
- ) for use_glu,
- use_post_ln,
- use_map in itertools.product([False, True], repeat=3)
- ])
- def test_arch(self, use_glu, use_post_ln, use_map):
- """Test different architecture configurations, fixed dropout_rate."""
- dropout_rate = 0.1
-
- torch.manual_seed(SEED)
- orig = OriginalVit(
- width=WIDTH,
- depth=DEPTH,
- num_heads=HEADS,
- use_glu=use_glu,
- use_post_layer_norm=use_post_ln,
- use_map=use_map,
- dropout_rate=dropout_rate,
- ).to(DEVICE)
-
- torch.manual_seed(SEED)
- cust = CustomVit(
- width=WIDTH,
- depth=DEPTH,
- num_heads=HEADS,
- use_glu=use_glu,
- use_post_layer_norm=use_post_ln,
- use_map=use_map,
- ).to(DEVICE)
-
- cust.load_state_dict(orig.state_dict()) # sync weights
- self.fwd_pass(orig, cust, dropout_rate)
-
- @parameterized.named_parameters(
- dict(testcase_name=''),)
def test_default_dropout(self):
"""Test default dropout_rate."""
+ OrigCls, CustCls = (OriginalVit, CustomVit)
+
+
+ # init model
+ rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
+
+ orig_model = OrigCls()
+ cust_model = CustCls()
+
+ fake_batch = jnp.ones(INPUT_SHAPE)
+
+ initial_params_original = orig_model.init({'params': rng},
+ fake_batch,
+ train=False)
+ initial_params_custom = cust_model.init({'params': rng},
+ fake_batch,
+ train=False)
- torch.manual_seed(SEED)
- orig = OriginalVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE)
- torch.manual_seed(SEED)
- cust = CustomVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE)
- cust.load_state_dict(orig.state_dict()) # sync weights
+ # fwd
+ x = jax.random.normal(data_rng, INPUT_SHAPE)
- x = torch.randn(BATCH, C, H, W, device=DEVICE)
for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
- torch.manual_seed(0)
- y1 = orig(x)
- torch.manual_seed(0)
- y2 = cust(x)
- assert_close(y1, y2, atol=0, rtol=0)
+ train = mode == 'train'
+ y1 = orig_model.apply(
+ initial_params_original,
+ x,
+ train=train,
+ rngs={'dropout': dropout_rng})
+ y2 = cust_model.apply(
+ initial_params_custom, x, train=train, rngs={'dropout': dropout_rng})
+ assert jnp.allclose(y1, y2, atol=0, rtol=0)
if __name__ == '__main__':
- absltest.main()
+ absltest.main()
\ No newline at end of file
diff --git a/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py
index 02f3a3d84..c5d849dce 100644
--- a/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py
+++ b/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py
@@ -1,84 +1,117 @@
"""
-Runs fwd pass with random input for LIBRISPEECH Conformer models and compares outputs.
-Run with:
- python3 tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
-
-NOTE: we don't test for default dropout_rate values, since they changed.
+Runs fwd pass with random input for LIBRISPEECH Conformer models and compares outputs.
+Run it as:
+  python3 tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py
"""
import os
from absl.testing import absltest
from absl.testing import parameterized
-import torch
-from torch.testing import assert_close
+import jax
+import jax.numpy as jnp
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
- ConformerConfig as OriginalConfig
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
- ConformerEncoderDecoder as OriginalModel
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import \
- ConformerConfig as CustomConfig
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import \
- ConformerEncoderDecoder as CustomModel
+from algoperf.workloads.librispeech_conformer.librispeech_jax.models import ConformerConfig as CustClsConfig
+from algoperf.workloads.librispeech_conformer.librispeech_jax.models import Conformer as CustCls
-N_LAYERS = 3
-B, T = 32, 36_000
-DEVICE = 'cuda'
+from algoperf.workloads.librispeech_conformer.librispeech_jax.models_ref import ConformerConfig as OrigClsConfig
+from algoperf.workloads.librispeech_conformer.librispeech_jax.models_ref import Conformer as OrigCls
-os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = True
-torch.use_deterministic_algorithms(mode=True)
-SEED = 1996
+# Model / test hyper-params
+INPUT_SHAPE = [(3200,), (3200,)]
+SEED = 1994
-class ConformerEquivalenceTest(parameterized.TestCase):
+class ModeEquivalenceTest(parameterized.TestCase):
@parameterized.named_parameters(
- dict(testcase_name='p=0.0', dropout_rate=0.0),
- dict(testcase_name='p=0.2', dropout_rate=0.2),
- dict(testcase_name='p=0.7', dropout_rate=0.7),
- dict(testcase_name='p=1.0', dropout_rate=1.0),
+ dict(
+ testcase_name='Conformer, p=0.0',
+ dropout_rate=0.0),
+ dict(
+ testcase_name='Conformer, p=0.1',
+ dropout_rate=0.1),
)
def test_forward(self, dropout_rate):
- torch.manual_seed(SEED)
- orig = OriginalModel(
- OriginalConfig(
- num_encoder_layers=N_LAYERS,
- attention_residual_dropout_rate=dropout_rate,
- conv_residual_dropout_rate=dropout_rate,
- feed_forward_residual_dropout_rate=dropout_rate,
- input_dropout_rate=dropout_rate,
- )).to(DEVICE)
+ # init model
+ rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
- torch.manual_seed(SEED)
- cust = CustomModel(CustomConfig(num_encoder_layers=N_LAYERS)).to(DEVICE)
+    orig_model = OrigCls(OrigClsConfig(
+        attention_dropout_rate=dropout_rate,
+        attention_residual_dropout_rate=dropout_rate,
+        conv_residual_dropout_rate=dropout_rate,
+        feed_forward_dropout_rate=dropout_rate,
+        feed_forward_residual_dropout_rate=dropout_rate,
+        input_dropout_rate=dropout_rate))
+ cust_model = CustCls(CustClsConfig())
- orig.load_state_dict(cust.state_dict()) # sync weights
+ fake_batch = [jnp.zeros((2, *x), jnp.float32) for x in INPUT_SHAPE]
- x = torch.randn(B, T, device=DEVICE)
- paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
+ initial_params_original = orig_model.init({'params': rng},
+ *fake_batch,
+ train=False)
+ initial_params_custom = cust_model.init({'params': rng},
+ *fake_batch,
+ train=False)
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
+ # fwd
+ x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE]
+
+ for train in [True]:
+ (y1, _), _ = orig_model.apply(
+ initial_params_original,
+ *x,
+ train=train,
+ rngs={'dropout': dropout_rng},
+ mutable=['batch_stats'],)
+ (y2, _), _ = cust_model.apply(
+ initial_params_custom,
+ *x,
+ train=train,
+ dropout_rate=dropout_rate,
+ rngs={'dropout': dropout_rng},
+ mutable=['batch_stats'])
+
+ assert jnp.allclose(y1, y2)
- torch.manual_seed(SEED)
- y1, p1 = orig(x, paddings)
- torch.manual_seed(SEED)
- y2, p2 = cust(x, paddings, dropout_rate=dropout_rate)
- assert_close(y1, y2, atol=0, rtol=0)
- assert_close(p1, p2, atol=0, rtol=0)
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(SEED)
- y2, p2 = cust(x, paddings)
- assert_close(y1, y2, atol=0, rtol=0)
- assert_close(p1, p2, atol=0, rtol=0)
+ @parameterized.named_parameters(
+ dict(testcase_name='Conformer, default'),
+ )
+ def test_default_dropout(self):
+ """Test default dropout_rate."""
+ # init model
+ rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
+
+ orig_model = OrigCls(OrigClsConfig())
+ cust_model = CustCls(CustClsConfig())
+
+ fake_batch = [jnp.zeros((2, *x), jnp.float32) for x in INPUT_SHAPE]
+
+ initial_params_original = orig_model.init({'params': rng},
+ *fake_batch,
+ train=False)
+ initial_params_custom = cust_model.init({'params': rng},
+ *fake_batch,
+ train=False)
+
+ # fwd
+ x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE]
+
+ for mode in ('train', 'eval'):
+ train = mode == 'train'
+ (y1, _), _ = orig_model.apply(
+ initial_params_original,
+ *x,
+ train=train,
+ rngs={'dropout': dropout_rng}, mutable=['batch_stats'])
+ (y2, _), _ = cust_model.apply(
+ initial_params_custom, *x, train=train, rngs={'dropout': dropout_rng}, mutable=['batch_stats'])
+
+
+ assert jnp.allclose(y1, y2)
if __name__ == '__main__':
- absltest.main()
+ absltest.main()
\ No newline at end of file
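The Conformer test above calls apply with mutable=['batch_stats'], which changes the return value of apply to an (outputs, updated_variables) pair, hence the (y1, _), _ unpacking. A toy sketch of that Flax pattern (a hypothetical module, not the Conformer itself):

import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyNet(nn.Module):

  @nn.compact
  def __call__(self, x, train):
    x = nn.Dense(4)(x)
    # BatchNorm keeps running statistics in the 'batch_stats' collection.
    return nn.BatchNorm(use_running_average=not train)(x)


x = jnp.ones((8, 3))
variables = ToyNet().init(jax.random.key(0), x, train=False)

# In training mode the running statistics are updated, so the collection must
# be declared mutable; apply then returns (output, updated_variables).
y, new_vars = ToyNet().apply(variables, x, train=True, mutable=['batch_stats'])
print(y.shape, list(new_vars.keys()))  # (8, 4) ['batch_stats']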
diff --git a/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py
index 7d6a94592..5155199f5 100644
--- a/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py
+++ b/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py
@@ -1,124 +1,113 @@
"""
-Runs fwd pass with random input for LIBRISPEECH Deepspeech models and compares outputs.
-Run with:
- python3 tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
-
-`dropout_rate` controls the following args:
-- `input_dropout_rate` (if None, 0.1
-- `feed_forward_dropout_rate` (if None, 0.1)
+Runs fwd pass with random input for LIBRISPEECH Deepspeech models and compares outputs.
+Run it as:
+  python3 tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py
"""
import os
from absl.testing import absltest
from absl.testing import parameterized
-import torch
-from torch.testing import assert_close
+import jax
+import jax.numpy as jnp
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \
- DeepspeechConfig as OriginalConfig
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \
- DeepspeechEncoderDecoder as OriginalModel
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import \
- DeepspeechConfig as CustomConfig
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import \
- DeepspeechEncoderDecoder as CustomModel
+from algoperf.workloads.librispeech_deepspeech.librispeech_jax.models import DeepspeechConfig as CustClsConfig
+from algoperf.workloads.librispeech_deepspeech.librispeech_jax.models import Deepspeech as CustCls
-B, T = 32, 30_000
-DEVICE = 'cuda'
-TORCH_COMPILE = False
+from algoperf.workloads.librispeech_deepspeech.librispeech_jax.models_ref import DeepspeechConfig as OrigClsConfig
+from algoperf.workloads.librispeech_deepspeech.librispeech_jax.models_ref import Deepspeech as OrigCls
-os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = True
-torch.use_deterministic_algorithms(mode=True)
-SEED = 1996
+# Model / test hyper-params
+INPUT_SHAPE = [(3200,), (3200,)]
+SEED = 1994
-class DeepSpeechEquivalenceTest(parameterized.TestCase):
+class ModeEquivalenceTest(parameterized.TestCase):
@parameterized.named_parameters(
- dict(testcase_name='p=0.0', dropout_rate=0.0),
- dict(testcase_name='p=0.2', dropout_rate=0.2),
- dict(testcase_name='p=0.7', dropout_rate=0.7),
- dict(testcase_name='p=1.0', dropout_rate=1.0),
+ dict(
+ testcase_name='Conformer, p=0.0',
+ dropout_rate=0.0),
+ dict(
+ testcase_name='Conformer, p=0.1',
+ dropout_rate=0.1),
)
def test_forward(self, dropout_rate):
- """Test different dropout_rate values."""
- torch.manual_seed(SEED)
- orig = OriginalModel(
- OriginalConfig(
- num_lstm_layers=2,
- num_ffn_layers=2,
- input_dropout_rate=dropout_rate,
- feed_forward_dropout_rate=dropout_rate,
- )).to(DEVICE)
+ # init model
+ rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
- torch.manual_seed(SEED)
- cust = CustomModel(CustomConfig(
- num_lstm_layers=2,
- num_ffn_layers=2,
- )).to(DEVICE)
+ orig_model = OrigCls(OrigClsConfig)
+ cust_model = CustCls(CustClsConfig)
- orig.load_state_dict(cust.state_dict()) # sync weights
+ fake_batch = [jnp.zeros((2, *x), jnp.float32) for x in INPUT_SHAPE]
- if TORCH_COMPILE:
- orig = torch.compile(orig)
- cust = torch.compile(cust)
+ initial_params_original = orig_model.init({'params': rng},
+ *fake_batch,
+ train=False)
+ initial_params_custom = cust_model.init({'params': rng},
+ *fake_batch,
+ train=False)
- x = torch.randn(B, T, device=DEVICE)
- paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
+ # fwd
+ x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE]
for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
-
- torch.manual_seed(SEED)
- y1, p1 = orig(x, paddings)
+ train = mode == 'train'
+ (y1, _), _ = orig_model.apply(
+ initial_params_original,
+ *x,
+ train=train,
+ rngs={'dropout': dropout_rng},
+ mutable=['batch_stats'],)
+ (y2, _), _ = cust_model.apply(
+ initial_params_custom,
+ *x,
+ train=train,
+ dropout_rate=dropout_rate,
+ rngs={'dropout': dropout_rng},
+ mutable=['batch_stats'])
- torch.manual_seed(SEED)
- y2, p2 = cust(x, paddings, dropout_rate=dropout_rate)
+ assert jnp.allclose(y1, y2)
- assert_close(y1, y2, atol=0, rtol=0)
- assert_close(p1, p2, atol=0, rtol=0)
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(SEED)
- y2, p2 = cust(x, paddings)
- assert_close(y1, y2, atol=0, rtol=0)
- assert_close(p1, p2, atol=0, rtol=0)
@parameterized.named_parameters(
- dict(testcase_name=''),)
+ dict(testcase_name='Conformer, default'),
+ )
def test_default_dropout(self):
"""Test default dropout_rate."""
+ # init model
+ rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
- torch.manual_seed(SEED)
- orig = OriginalModel(OriginalConfig(num_lstm_layers=2,
- num_ffn_layers=2)).to(DEVICE)
- torch.manual_seed(SEED)
- cust = CustomModel(CustomConfig(num_lstm_layers=2,
- num_ffn_layers=2)).to(DEVICE)
- orig.load_state_dict(cust.state_dict()) # sync weights
+ orig_model = OrigCls(OrigClsConfig)
+ cust_model = CustCls(CustClsConfig)
- if TORCH_COMPILE:
- orig = torch.compile(orig)
- cust = torch.compile(cust)
+ fake_batch = [jnp.zeros((2, *x), jnp.float32) for x in INPUT_SHAPE]
- x = torch.randn(B, T, device=DEVICE)
- paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
+ initial_params_original = orig_model.init({'params': rng},
+ *fake_batch,
+ train=False)
+ initial_params_custom = cust_model.init({'params': rng},
+ *fake_batch,
+ train=False)
+
+ # fwd
+ x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE]
for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
- torch.manual_seed(SEED)
- y1, p1 = orig(x, paddings)
- torch.manual_seed(SEED)
- y2, p2 = cust(x, paddings)
- assert_close(y1, y2, atol=0, rtol=0)
- assert_close(p1, p2, atol=0, rtol=0)
+ train = mode == 'train'
+ (y1, _), _ = orig_model.apply(
+ initial_params_original,
+ *x,
+ train=train,
+ rngs={'dropout': dropout_rng}, mutable=['batch_stats'])
+ (y2, _), _ = cust_model.apply(
+ initial_params_custom, *x, train=train, rngs={'dropout': dropout_rng}, mutable=['batch_stats'])
+
+
+ assert jnp.allclose(y1, y2)
if __name__ == '__main__':
- absltest.main()
+ absltest.main()
\ No newline at end of file
diff --git a/tests/dropout_fix/ogbg_pytorch_jax/test_model_equivalence.py b/tests/dropout_fix/ogbg_pytorch_jax/test_model_equivalence.py
deleted file mode 100644
index 3b3feb680..000000000
--- a/tests/dropout_fix/ogbg_pytorch_jax/test_model_equivalence.py
+++ /dev/null
@@ -1,118 +0,0 @@
-"""
-Runs fwd pass with random graphs for OGBG GNN models and compares outputs.
-Run with:
- python3 tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
-"""
-
-import os
-import random
-
-from absl.testing import absltest
-from absl.testing import parameterized
-from jraph import GraphsTuple
-import numpy as np
-import torch
-from torch.testing import assert_close
-
-from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as OriginalModel
-from algoperf.workloads.ogbg.ogbg_pytorch.models_dropout import \
- GNN as CustomModel
-
-B, N, E = 8, 20, 40 # graphs, nodes/graph, edges/graph
-NODE_FDIM, EDGE_FDIM = 9, 3 # expected feature dims
-DEVICE = 'cuda'
-
-os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = True
-torch.use_deterministic_algorithms(True)
-SEED = 1996
-
-
-def _rand_graph():
- total_nodes, total_edges = B * N, B * E
- nodes = torch.randn(total_nodes, NODE_FDIM, device=DEVICE)
- edges = torch.randn(total_edges, EDGE_FDIM, device=DEVICE)
- senders, receivers = [], []
- for i in range(B):
- offset = i * N
- s = torch.randint(N, (E,), device=DEVICE) + offset
- r = torch.randint(N, (E,), device=DEVICE) + offset
- senders.append(s), receivers.append(r)
- senders = torch.cat(senders)
- receivers = torch.cat(receivers)
- n_node = torch.full((B,), N, device=DEVICE, dtype=torch.int32)
- n_edge = torch.full((B,), E, device=DEVICE, dtype=torch.int32)
- return GraphsTuple(nodes, edges, receivers, senders, None, n_node, n_edge)
-
-
-class GNNEquivalenceTest(parameterized.TestCase):
-
- @parameterized.named_parameters(
- dict(testcase_name='0.0', dropout_rate=0.0),
- dict(testcase_name='0.2', dropout_rate=0.2),
- dict(testcase_name='0.7', dropout_rate=0.7),
- dict(testcase_name='1.0', dropout_rate=1.0),
- )
- def test_forward(self, dropout_rate):
- """Test different dropout_rates."""
-
- orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE)
- cust = CustomModel().to(DEVICE)
- orig.load_state_dict(cust.state_dict()) # sync weights
-
- graph = _rand_graph()
-
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
-
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y1 = orig(graph)
-
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y2 = cust(graph, dropout_rate=dropout_rate)
-
- assert_close(y1, y2, atol=0, rtol=0)
-
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y2 = cust(graph)
- assert_close(y1, y2, atol=0, rtol=0)
-
- @parameterized.named_parameters(
- dict(testcase_name=''),)
- def test_default_dropout(self):
- """Test default dropout_rate."""
-
- orig = OriginalModel().to(DEVICE)
- cust = CustomModel().to(DEVICE)
- orig.load_state_dict(cust.state_dict()) # sync weights
-
- graph = _rand_graph()
-
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
-
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y1 = orig(graph)
-
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y2 = cust(graph)
-
- assert_close(y1, y2, atol=0, rtol=0)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/tests/dropout_fix/wmt_pytorch_jax/test_model_equivalence.py b/tests/dropout_fix/wmt_pytorch_jax/test_model_equivalence.py
deleted file mode 100644
index 03f289a68..000000000
--- a/tests/dropout_fix/wmt_pytorch_jax/test_model_equivalence.py
+++ /dev/null
@@ -1,121 +0,0 @@
-"""
-Runs fwd pass with random input for WMT Transformer models and compares outputs.
-Run with:
- python3 tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
-"""
-
-import os
-import random
-
-from absl.testing import absltest
-from absl.testing import parameterized
-import numpy as np
-import torch
-from torch.testing import assert_close
-
-from algoperf.workloads.wmt.wmt_pytorch.models import \
- Transformer as OriginalModel
-from algoperf.workloads.wmt.wmt_pytorch.models_dropout import \
- Transformer as CustomModel
-
-B, SRC_LEN, TGT_LEN, NTOK = 16, 80, 80, 32_000
-DEVICE = "cuda"
-SEED = 1996
-
-os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = True
-torch.use_deterministic_algorithms(True)
-
-
-def _rand_tokens(bs, seqlen):
- return torch.randint(1, NTOK, (bs, seqlen), device=DEVICE)
-
-
-class TransformerEquivalenceTest(parameterized.TestCase):
-
- @parameterized.named_parameters(
- # NOTE: removed dropout=1.0 since it will generate nan in scaled_dot_product_attention
- dict(testcase_name="0.0", dropout_rate=0.0, compile=False),
- dict(testcase_name="0.2", dropout_rate=0.2, compile=False),
- dict(testcase_name="0.7", dropout_rate=0.7, compile=False),
- dict(testcase_name="p=0.0_compile", dropout_rate=0.0, compile=True),
- dict(testcase_name="p=0.2_compile", dropout_rate=0.2, compile=True),
- dict(testcase_name="p=0.7_compile", dropout_rate=0.7, compile=True),
- )
- def test_dropout_value(self, dropout_rate, compile):
-
- orig = OriginalModel(
- dropout_rate=dropout_rate,
- attention_dropout_rate=dropout_rate).to(DEVICE)
- cust = CustomModel().to(DEVICE)
-
- orig.load_state_dict(cust.state_dict()) # sync weights
-
- if compile:
- orig = torch.compile(orig)
- cust = torch.compile(cust)
-
- src = _rand_tokens(B, SRC_LEN)
- tgt = _rand_tokens(B, TGT_LEN)
-
- for mode in ("train", "eval"):
- getattr(orig, mode)()
- getattr(cust, mode)()
-
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y1 = orig(src, tgt)
-
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y2 = cust(src, tgt, dropout_rate=dropout_rate)
-
- assert_close(y1, y2, atol=0, rtol=0)
-
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y2 = cust(src, tgt)
- assert_close(y1, y2, atol=0, rtol=0)
-
- @parameterized.named_parameters(
- dict(testcase_name="default", compile=False),
- dict(testcase_name="default_compile", compile=True),
- )
- def test_default(self, compile):
-
- orig = OriginalModel().to(DEVICE)
- cust = CustomModel().to(DEVICE)
-
- orig.load_state_dict(cust.state_dict()) # sync weights
-
- if compile:
- orig = torch.compile(orig)
- cust = torch.compile(cust)
-
- src = _rand_tokens(B, SRC_LEN)
- tgt = _rand_tokens(B, TGT_LEN)
-
- for mode in ("train", "eval"):
- getattr(orig, mode)()
- getattr(cust, mode)()
-
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y1 = orig(src, tgt)
-
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y2 = cust(src, tgt)
-
- assert_close(y1, y2, atol=0, rtol=0)
-
-
-if __name__ == "__main__":
- absltest.main()
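The JAX equivalence tests added in the following patches all follow the same pattern: build the reference model (from models_ref, with the dropout rate fixed in its config or constructor) and the custom model (from models, with the dropout rate passed at call time), initialize both with the same RNG so their parameters match, run both forward passes with a shared 'dropout' RNG stream, and compare outputs with jnp.allclose. A minimal, self-contained sketch of that pattern, using two hypothetical Flax modules (RefModel and CustModel are illustrative stand-ins, not the workload models):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn


    class RefModel(nn.Module):
      """Stand-in for a models_ref module: dropout rate fixed as an attribute."""
      dropout_rate: float = 0.1

      @nn.compact
      def __call__(self, x, train=True):
        x = nn.Dense(16)(x)
        return nn.Dropout(self.dropout_rate, deterministic=not train)(x)


    class CustModel(nn.Module):
      """Stand-in for a models module: dropout rate supplied at call time."""

      @nn.compact
      def __call__(self, x, train=True, dropout_rate=0.1):
        x = nn.Dense(16)(x)
        return nn.Dropout(dropout_rate, deterministic=not train)(x)


    rng, data_rng, dropout_rng = jax.random.split(jax.random.key(1994), 3)
    x = jax.random.normal(data_rng, (2, 8))

    ref, cust = RefModel(dropout_rate=0.1), CustModel()
    params_ref = ref.init({'params': rng}, x, train=False)
    params_cust = cust.init({'params': rng}, x, train=False)

    # Same parameters and the same 'dropout' stream, so outputs should match.
    y1 = ref.apply(params_ref, x, train=True, rngs={'dropout': dropout_rng})
    y2 = cust.apply(
        params_cust, x, train=True, dropout_rate=0.1,
        rngs={'dropout': dropout_rng})
    assert jnp.allclose(y1, y2)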
From caacb84cb4c59226e2031a5c8d2a715242830662 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Sat, 21 Jun 2025 08:05:59 +0000
Subject: [PATCH 083/123] add tests
---
.../ogbg_jax/test_model_equivalence.py | 133 ++++++++++++++++++
.../wmt_jax/test_model_equivalence.py | 129 +++++++++++++++++
2 files changed, 262 insertions(+)
create mode 100644 tests/dropout_fix/ogbg_jax/test_model_equivalence.py
create mode 100644 tests/dropout_fix/wmt_jax/test_model_equivalence.py
diff --git a/tests/dropout_fix/ogbg_jax/test_model_equivalence.py b/tests/dropout_fix/ogbg_jax/test_model_equivalence.py
new file mode 100644
index 000000000..1b5b8a180
--- /dev/null
+++ b/tests/dropout_fix/ogbg_jax/test_model_equivalence.py
@@ -0,0 +1,133 @@
+"""
+Runs fwd pass with random input for OGBG GNN models and compares outputs.
+"""
+
+import os
+
+import jraph
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import jax
+import jax.numpy as jnp
+
+from algoperf.workloads.ogbg.ogbg_jax.models_ref import \
+ GNN as OrigCls
+from algoperf.workloads.ogbg.ogbg_jax.models import \
+ GNN as CustCls
+
+# Model / test hyper-params
+SEED = 1994
+
+class ModeEquivalenceTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ dict(
+ testcase_name='OGBG, p=0.0',
+ dropout_rate=0.0),
+ dict(
+ testcase_name='OGBG, p=0.1',
+ dropout_rate=0.1),
+ )
+ def test_forward(self, dropout_rate):
+ # init model
+ rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
+
+ orig_model = OrigCls(num_outputs=128, dropout_rate=dropout_rate)
+ cust_model = CustCls(num_outputs=128)
+
+ fake_batch = jraph.GraphsTuple(
+ n_node=jnp.asarray([1]),
+ n_edge=jnp.asarray([1]),
+ nodes=jnp.ones((1, 9)),
+ edges=jnp.ones((1, 3)),
+ globals=jnp.zeros((1, 128)),
+ senders=jnp.asarray([0]),
+ receivers=jnp.asarray([0]))
+
+ initial_params_original = orig_model.init({'params': rng},
+ fake_batch,
+ train=False)
+ initial_params_custom = cust_model.init({'params': rng},
+ fake_batch,
+ train=False)
+
+ # fwd
+ x = jraph.GraphsTuple(
+ n_node=jnp.asarray([1]),
+ n_edge=jnp.asarray([1]),
+ nodes=jnp.ones((1, 9)),
+ edges=jnp.ones((1, 3)),
+ globals=jnp.zeros((1, 128)),
+ senders=jnp.asarray([0]),
+ receivers=jnp.asarray([0]))
+
+ for mode in ('train', 'eval'):
+ train = mode == 'train'
+ y1 = orig_model.apply(
+ initial_params_original,
+ x,
+ train=train,
+ rngs={'dropout': dropout_rng})
+ y2 = cust_model.apply(
+ initial_params_custom,
+ x,
+ train=train,
+ dropout_rate=dropout_rate,
+ rngs={'dropout': dropout_rng})
+
+ assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
+
+ @parameterized.named_parameters(
+ dict(testcase_name='OGBG, default'),
+ )
+ def test_default_dropout(self):
+ """Test default dropout_rate."""
+
+
+ # init model
+ rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
+
+ orig_model = OrigCls(num_outputs=128)
+ cust_model = CustCls(num_outputs=128)
+
+ fake_batch = jraph.GraphsTuple(
+ n_node=jnp.asarray([1]),
+ n_edge=jnp.asarray([1]),
+ nodes=jnp.ones((1, 9)),
+ edges=jnp.ones((1, 3)),
+ globals=jnp.zeros((1, 128)),
+ senders=jnp.asarray([0]),
+ receivers=jnp.asarray([0]))
+
+ initial_params_original = orig_model.init({'params': rng},
+ fake_batch,
+ train=False)
+ initial_params_custom = cust_model.init({'params': rng},
+ fake_batch,
+ train=False)
+
+ # fwd
+ x = jraph.GraphsTuple(
+ n_node=jnp.asarray([1]),
+ n_edge=jnp.asarray([1]),
+ nodes=jnp.ones((1, 9)),
+ edges=jnp.ones((1, 3)),
+ globals=jnp.zeros((1, 128)),
+ senders=jnp.asarray([0]),
+ receivers=jnp.asarray([0]))
+
+ for mode in ('train', 'eval'):
+ train = mode == 'train'
+ y1 = orig_model.apply(
+ initial_params_original,
+ x,
+ train=train,
+ rngs={'dropout': dropout_rng})
+ y2 = cust_model.apply(
+ initial_params_custom, x, train=train, rngs={'dropout': dropout_rng})
+
+ assert jnp.allclose(y1, y2, atol=0, rtol=0)
+
+if __name__ == '__main__':
+ absltest.main()
\ No newline at end of file
diff --git a/tests/dropout_fix/wmt_jax/test_model_equivalence.py b/tests/dropout_fix/wmt_jax/test_model_equivalence.py
new file mode 100644
index 000000000..c0443bb50
--- /dev/null
+++ b/tests/dropout_fix/wmt_jax/test_model_equivalence.py
@@ -0,0 +1,129 @@
+"""
+Runs fwd pass with random input for WMT Transformer models and compares outputs.
+Run it as:
+  python3 tests/dropout_fix/wmt_jax/test_model_equivalence.py
+"""
+
+import os
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import jax
+import jax.numpy as jnp
+
+from algoperf.workloads.wmt.wmt_jax.models import TransformerConfig as CustClsConfig
+from algoperf.workloads.wmt.wmt_jax.models import Transformer as CustCls
+
+from algoperf.workloads.wmt.wmt_jax.models_ref import TransformerConfig as OrigClsConfig
+from algoperf.workloads.wmt.wmt_jax.models_ref import Transformer as OrigCls
+
+
+# Model / test hyper-params
+SEED = 1994
+
+class ModeEquivalenceTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ dict(
+ testcase_name='WMT, p=0.0',
+ dropout_rate=0.0),
+ dict(
+ testcase_name='WMT p=0.1',
+ dropout_rate=0.1),
+ )
+ def test_forward(self, dropout_rate):
+
+ # init model
+ rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
+
+ orig_model = OrigCls(OrigClsConfig)
+ cust_model = CustCls(CustClsConfig)
+
+ init_fake_batch_size = 8
+ input_shape = (init_fake_batch_size, 256)
+ target_shape = (init_fake_batch_size, 256)
+
+ initial_params_original = orig_model.init({'params': rng},
+ jnp.ones(input_shape, jnp.float32),
+ jnp.ones(target_shape, jnp.float32),
+ train=False)
+ initial_params_custom = cust_model.init({'params': rng},
+ jnp.ones(input_shape, jnp.float32),
+ jnp.ones(target_shape, jnp.float32),
+ train=False)
+
+ # fwd
+
+ for mode in ('train', 'eval'):
+ train = mode == 'train'
+ y1 = orig_model.apply(
+ initial_params_original,
+ jnp.ones(input_shape, jnp.float32),
+ jnp.ones(target_shape, jnp.float32),
+ train=train,
+ rngs={'dropout': dropout_rng},
+ mutable=['batch_stats'],)
+ y2 = cust_model.apply(
+ initial_params_custom,
+ jnp.ones(input_shape, jnp.float32),
+ jnp.ones(target_shape, jnp.float32),
+ train=train,
+ dropout_rate=dropout_rate,
+ rngs={'dropout': dropout_rng},
+ mutable=['batch_stats'])
+
+ for i in range(len(y1)):
+ assert jnp.allclose(y1[i], y2[i])
+
+
+
+ @parameterized.named_parameters(
+ dict(testcase_name='WMT, default'),
+ )
+ def test_default_dropout(self):
+ """Test default dropout_rate."""
+ # init model
+ rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
+
+ orig_model = OrigCls(OrigClsConfig)
+ cust_model = CustCls(CustClsConfig)
+
+ init_fake_batch_size = 8
+ input_shape = (init_fake_batch_size, 256)
+ target_shape = (init_fake_batch_size, 256)
+
+ initial_params_original = orig_model.init({'params': rng},
+ jnp.ones(input_shape, jnp.float32),
+ jnp.ones(target_shape, jnp.float32),
+ train=False)
+ initial_params_custom = cust_model.init({'params': rng},
+ jnp.ones(input_shape, jnp.float32),
+ jnp.ones(target_shape, jnp.float32),
+ train=False)
+
+ # fwd
+ x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE]
+
+ for mode in ('train', 'eval'):
+ train = mode == 'train'
+ y1 = orig_model.apply(
+ initial_params_original,
+ jnp.ones(input_shape, jnp.float32),
+ jnp.ones(target_shape, jnp.float32),
+ train=train,
+ rngs={'dropout': dropout_rng}, mutable=['batch_stats'])
+ y2 = cust_model.apply(
+ initial_params_custom,
+ jnp.ones(input_shape, jnp.float32),
+ jnp.ones(target_shape, jnp.float32),
+ train=train, rngs={'dropout': dropout_rng},
+ mutable=['batch_stats'])
+
+ print(jax.tree.map(lambda x: x.shape, y1))
+
+ for i in range(len(y1)):
+ assert jnp.allclose(y1[i], y2[i])
+
+
+if __name__ == '__main__':
+ absltest.main()
\ No newline at end of file
From 8b0a12529870be76cd216835f943462424ec5131 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Sat, 21 Jun 2025 08:28:09 +0000
Subject: [PATCH 084/123] fix wmt test
---
.../wmt_jax/test_model_equivalence.py | 104 ++++++++----------
1 file changed, 44 insertions(+), 60 deletions(-)
diff --git a/tests/dropout_fix/wmt_jax/test_model_equivalence.py b/tests/dropout_fix/wmt_jax/test_model_equivalence.py
index c0443bb50..d420bf691 100644
--- a/tests/dropout_fix/wmt_jax/test_model_equivalence.py
+++ b/tests/dropout_fix/wmt_jax/test_model_equivalence.py
@@ -26,18 +26,20 @@ class ModeEquivalenceTest(parameterized.TestCase):
@parameterized.named_parameters(
dict(
testcase_name='WMT, p=0.0',
- dropout_rate=0.0),
+ dropout_rate=0.0,
+ train=True),
dict(
testcase_name='WMT p=0.1',
- dropout_rate=0.1),
+ dropout_rate=0.1,
+ train=False),
)
- def test_forward(self, dropout_rate):
+ def test_forward(self, dropout_rate, train):
# init model
rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
- orig_model = OrigCls(OrigClsConfig)
- cust_model = CustCls(CustClsConfig)
+ orig_model = OrigCls(OrigClsConfig(deterministic=not train, attention_dropout_rate=dropout_rate, dropout_rate=dropout_rate))
+ cust_model = CustCls(CustClsConfig(deterministic=not train))
init_fake_batch_size = 8
input_shape = (init_fake_batch_size, 256)
@@ -45,48 +47,39 @@ def test_forward(self, dropout_rate):
initial_params_original = orig_model.init({'params': rng},
jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32),
- train=False)
+ jnp.ones(target_shape, jnp.float32))
initial_params_custom = cust_model.init({'params': rng},
jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32),
- train=False)
+ jnp.ones(target_shape, jnp.float32),)
# fwd
- for mode in ('train', 'eval'):
- train = mode == 'train'
- y1 = orig_model.apply(
- initial_params_original,
- jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32),
- train=train,
- rngs={'dropout': dropout_rng},
- mutable=['batch_stats'],)
- y2 = cust_model.apply(
- initial_params_custom,
- jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32),
- train=train,
- dropout_rate=dropout_rate,
- rngs={'dropout': dropout_rng},
- mutable=['batch_stats'])
-
- for i in range(len(y1)):
- assert jnp.allclose(y1[i], y2[i])
+ y1 = orig_model.apply(
+ initial_params_original,
+ jnp.ones(input_shape, jnp.float32),
+ jnp.ones(target_shape, jnp.float32),
+ rngs={'dropout': dropout_rng})
+
+ y2 = cust_model.apply(
+ initial_params_custom,
+ jnp.ones(input_shape, jnp.float32),
+ jnp.ones(target_shape, jnp.float32),
+ dropout_rate=dropout_rate,
+ rngs={'dropout': dropout_rng})
+
+ assert jnp.allclose(y1, y2)
@parameterized.named_parameters(
- dict(testcase_name='WMT, default'),
+ dict(testcase_name='WMT, default train', train=True),
+ dict(testcase_name='WMT, default eval', train=False),
)
- def test_default_dropout(self):
+ def test_default_dropout(self, train):
"""Test default dropout_rate."""
- # init model
rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
-
- orig_model = OrigCls(OrigClsConfig)
- cust_model = CustCls(CustClsConfig)
+ orig_model = OrigCls(OrigClsConfig(deterministic=not train))
+ cust_model = CustCls(CustClsConfig(deterministic=not train))
init_fake_batch_size = 8
input_shape = (init_fake_batch_size, 256)
@@ -94,35 +87,26 @@ def test_default_dropout(self):
initial_params_original = orig_model.init({'params': rng},
jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32),
- train=False)
+ jnp.ones(target_shape, jnp.float32))
initial_params_custom = cust_model.init({'params': rng},
jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32),
- train=False)
+ jnp.ones(target_shape, jnp.float32))
# fwd
- x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE]
-
- for mode in ('train', 'eval'):
- train = mode == 'train'
- y1 = orig_model.apply(
- initial_params_original,
- jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32),
- train=train,
- rngs={'dropout': dropout_rng}, mutable=['batch_stats'])
- y2 = cust_model.apply(
- initial_params_custom,
- jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32),
- train=train, rngs={'dropout': dropout_rng},
- mutable=['batch_stats'])
-
- print(jax.tree.map(lambda x: x.shape, y1))
-
- for i in range(len(y1)):
- assert jnp.allclose(y1[i], y2[i])
+
+ y1 = orig_model.apply(
+ initial_params_original,
+ jnp.ones(input_shape, jnp.float32),
+ jnp.ones(target_shape, jnp.float32),
+ rngs={'dropout': dropout_rng})
+
+ y2 = cust_model.apply(
+ initial_params_custom,
+ jnp.ones(input_shape, jnp.float32),
+ jnp.ones(target_shape, jnp.float32),
+ rngs={'dropout': dropout_rng})
+
+ assert jnp.allclose(y1, y2)
if __name__ == '__main__':
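When the reference and custom models share parameters and the same 'dropout' RNG stream, their outputs typically agree to within floating-point noise, so tight tolerances (or even exact checks) are usually sufficient; very loose tolerances such as atol=1e-3 can hide real divergence. For models that return a tuple or other pytree of outputs, a leaf-wise comparison keeps the assertion in one place. A small sketch of such a helper (trees_allclose is an illustrative name, not a utility from this repository):

    import jax
    import jax.numpy as jnp


    def trees_allclose(a, b, rtol=1e-6, atol=0.0):
      """Return True iff every corresponding leaf of two pytrees is allclose."""
      leaf_ok = jax.tree_util.tree_map(
          lambda x, y: bool(jnp.allclose(x, y, rtol=rtol, atol=atol)), a, b)
      return all(jax.tree_util.tree_leaves(leaf_ok))


    # Example: compare (logits, logit_paddings) tuples from two forward passes.
    out_ref = (jnp.ones((2, 4)), jnp.zeros((2,)))
    out_cust = (jnp.ones((2, 4)), jnp.zeros((2,)))
    assert trees_allclose(out_ref, out_cust)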
From 6c7d69590f58c55e928165c958e51a86af697928 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Tue, 24 Jun 2025 00:00:40 +0000
Subject: [PATCH 085/123] remove dropout fix tests
---
.../criteo1tb_jax/test_model_equivalence.py | 156 ------------------
.../test_model_equivalence.py | 120 --------------
.../fastmri_jax/test_model_equivalence.py | 118 -------------
.../fastmri_pytorch/test_model_equivalence.py | 130 ---------------
.../test_model_equivalence.py | 110 ------------
.../test_model_equivalence.py | 138 ----------------
.../test_model_equivalence.py | 117 -------------
.../test_model_equivalence.py | 84 ----------
.../test_model_equivalence.py | 113 -------------
.../test_model_equivalence.py | 124 --------------
.../ogbg_jax/test_model_equivalence.py | 133 ---------------
.../ogbg_pytorch/test_model_equivalence.py | 118 -------------
.../wmt_jax/test_model_equivalence.py | 113 -------------
.../wmt_pytorch/test_model_equivalence.py | 121 --------------
14 files changed, 1695 deletions(-)
delete mode 100644 tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py
delete mode 100644 tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
delete mode 100644 tests/dropout_fix/fastmri_jax/test_model_equivalence.py
delete mode 100644 tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
delete mode 100644 tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py
delete mode 100644 tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
delete mode 100644 tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py
delete mode 100644 tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
delete mode 100644 tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py
delete mode 100644 tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
delete mode 100644 tests/dropout_fix/ogbg_jax/test_model_equivalence.py
delete mode 100644 tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
delete mode 100644 tests/dropout_fix/wmt_jax/test_model_equivalence.py
delete mode 100644 tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
diff --git a/tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py
deleted file mode 100644
index 10aeaa650..000000000
--- a/tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py
+++ /dev/null
@@ -1,156 +0,0 @@
-"""
-Runs fwd pass with random input for our DLRM models and compares outputs.
-Run it as:
- python3 tests/dropout_fix/criteo1tb_jax/test_model_equivalence.py
-"""
-
-import os
-
-from absl.testing import absltest
-from absl.testing import parameterized
-import jax
-import jax.numpy as jnp
-# import equinox as eqx
-
-from jax.tree_util import tree_structure, tree_leaves, tree_map
-
-
-def pytrees_are_equal(a, b, rtol=1e-5, atol=1e-8):
- """
- A custom function to check if two PyTrees are equal, handling floats with a tolerance.
- """
- # 1. Check if the structures are the same
- if tree_structure(a) != tree_structure(b):
- return False
-
- # 2. Define a comparison function for leaves
- def leaf_comparator(x, y):
- # Use allclose for floating-point JAX arrays
- if isinstance(x, jnp.ndarray) and jnp.issubdtype(x.dtype, jnp.floating):
- return jnp.allclose(x, y, rtol=rtol, atol=atol)
- # Use standard equality for everything else
- else:
- return x == y
-
- # 3. Map the comparison function over the trees and check if all results are True
- # We also need to flatten the results of the tree_map and check if all are True
- comparison_tree = tree_map(leaf_comparator, a, b)
- all_equal = all(tree_leaves(comparison_tree))
-
- return all_equal
-
-from algoperf.workloads.criteo1tb.criteo1tb_jax.models_ref import \
- DLRMResNet as OriginalDLRMResNet
-from algoperf.workloads.criteo1tb.criteo1tb_jax.models_ref import \
- DlrmSmall as OriginalDlrmSmall
-from algoperf.workloads.criteo1tb.criteo1tb_jax.models import \
- DLRMResNet as CustomDLRMResNet
-from algoperf.workloads.criteo1tb.criteo1tb_jax.models import \
- DlrmSmall as CustomDlrmSmall
-
-BATCH, DENSE, SPARSE = 16, 13, 26
-FEATURES = DENSE + SPARSE
-VOCAB = 1000
-DEVICE = 'cuda'
-SEED = 1996
-
-
-class ModelEquivalenceTest(parameterized.TestCase):
-
- @parameterized.named_parameters(
- dict(
- testcase_name='DLRMResNet, p=0.0',
- model='dlrm_resnet',
- dropout_rate=0.0),
- dict(
- testcase_name='DlrmSmall, p=0.0',
- model='dlrm_small',
- dropout_rate=0.0),
- dict(
- testcase_name='DLRMResNet, p=0.1',
- model='dlrm_resnet',
- dropout_rate=0.1),
- dict(
- testcase_name='DlrmSmall, p=0.1',
- model='dlrm_small',
- dropout_rate=0.1),
- )
- def test_forward(self, model, dropout_rate):
- OrigCls, CustCls = (
- (OriginalDLRMResNet, CustomDLRMResNet)
- if model == 'dlrm_resnet'
- else (OriginalDlrmSmall, CustomDlrmSmall)
- )
-
- # init model
- rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
- fake_batch = jnp.ones((2, 39))
- assert dropout_rate == 0.1
- orig_model = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate)
- cust_model = CustCls(vocab_size=VOCAB)
- initial_params_original = orig_model.init({'params': rng},
- fake_batch,
- train=False)
- initial_params_custom = cust_model.init({'params': rng},
- fake_batch,
- train=False)
- assert pytrees_are_equal(
- initial_params_original, initial_params_custom, rtol=1e-6)
-
- x = jax.random.normal(data_rng, shape=(BATCH, FEATURES))
-
- for mode in ('train', 'eval'):
- train = mode == 'train'
- y1 = orig_model.apply(
- initial_params_original,
- x,
- train=train,
- rngs={'dropout': dropout_rng})
- y2 = cust_model.apply(
- initial_params_custom,
- x,
- train=train,
- dropout_rate=dropout_rate,
- rngs={'dropout': dropout_rng})
-
- assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
-
- @parameterized.named_parameters(
- dict(testcase_name='DLRMResNet, default', model='dlrm_resnet'),
- dict(testcase_name='DlrmSmall, default', model='dlrm_small'),
- )
- def test_default_dropout(self, model):
- """Test default dropout_rate."""
- OrigCls, CustCls = (
- (OriginalDLRMResNet, CustomDLRMResNet)
- if model == 'dlrm_resnet'
- else (OriginalDlrmSmall, CustomDlrmSmall)
- )
- rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
- fake_batch = jnp.ones((2, 39))
- orig_model = OrigCls(vocab_size=VOCAB)
- cust_model = CustCls(vocab_size=VOCAB)
- initial_params_original = orig_model.init({'params': rng},
- fake_batch,
- train=False)
- initial_params_custom = cust_model.init({'params': rng},
- fake_batch,
- train=False)
-
- x = jax.random.normal(data_rng, shape=(BATCH, FEATURES))
-
- for mode in ('train', 'eval'):
- train = mode == 'train'
- y1 = orig_model.apply(
- initial_params_original,
- x,
- train=train,
- rngs={'dropout': dropout_rng})
- y2 = cust_model.apply(
- initial_params_custom, x, train=train, rngs={'dropout': dropout_rng})
-
- assert jnp.allclose(y1, y2, atol=0, rtol=0)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py b/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
deleted file mode 100644
index 733052dd0..000000000
--- a/tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
+++ /dev/null
@@ -1,120 +0,0 @@
-"""
-Runs fwd pass with random input for our DLRM models and compares outputs.
-Run it as:
- python3 tests/dropout_fix/criteo1tb_pytorch/test_model_equivalence.py
-"""
-
-import os
-
-from absl.testing import absltest
-from absl.testing import parameterized
-import torch
-from torch.testing import assert_close
-
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import \
- DLRMResNet as OriginalDLRMResNet
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import \
- DlrmSmall as OriginalDlrmSmall
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models_dropout import \
- DLRMResNet as CustomDLRMResNet
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models_dropout import \
- DlrmSmall as CustomDlrmSmall
-
-BATCH, DENSE, SPARSE = 16, 13, 26
-FEATURES = DENSE + SPARSE
-VOCAB = 1000
-DEVICE = 'cuda'
-SEED = 1996
-
-os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = True
-torch.use_deterministic_algorithms(True)
-
-
-class ModelEquivalenceTest(parameterized.TestCase):
-
- @parameterized.named_parameters(
- dict(
- testcase_name='DLRMResNet, p=0.0',
- model='dlrm_resnet',
- dropout_rate=0.0),
- dict(
- testcase_name='DlrmSmall, p=0.0',
- model='dlrm_small',
- dropout_rate=0.0),
- dict(
- testcase_name='DLRMResNet, p=0.1',
- model='dlrm_resnet',
- dropout_rate=0.1),
- dict(
- testcase_name='DlrmSmall, p=0.1',
- model='dlrm_small',
- dropout_rate=0.1),
- dict(
- testcase_name='DLRMResNet, p=1.0',
- model='dlrm_resnet',
- dropout_rate=1.0),
- dict(
- testcase_name='DlrmSmall, p=1.0',
- model='dlrm_small',
- dropout_rate=1.0),
- )
- def test_forward(self, model, dropout_rate):
- OrigCls, CustCls = (
- (OriginalDLRMResNet, CustomDLRMResNet)
- if model == 'dlrm_resnet'
- else (OriginalDlrmSmall, CustomDlrmSmall)
- )
-
- torch.manual_seed(SEED)
- orig = OrigCls(vocab_size=VOCAB, dropout_rate=dropout_rate).to(DEVICE)
-
- torch.manual_seed(SEED)
- cust = CustCls(vocab_size=VOCAB).to(DEVICE)
-
- x = torch.randn(BATCH, FEATURES, device=DEVICE)
-
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
- torch.manual_seed(SEED)
- y1 = orig(x)
- torch.manual_seed(SEED)
- y2 = cust(x, dropout_rate)
- assert_close(y1, y2, atol=0, rtol=0)
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(SEED)
- y2 = cust(x)
- assert_close(y1, y2, atol=0, rtol=0)
-
- @parameterized.named_parameters(
- dict(testcase_name='DLRMResNet, default', model='dlrm_resnet'),
- dict(testcase_name='DlrmSmall, default', model='dlrm_small'),
- )
- def test_default_dropout(self, model):
- """Test default dropout_rate."""
- OrigCls, CustCls = (
- (OriginalDLRMResNet, CustomDLRMResNet)
- if model == 'dlrm_resnet'
- else (OriginalDlrmSmall, CustomDlrmSmall)
- )
-
- torch.manual_seed(SEED)
- orig = OrigCls(vocab_size=VOCAB).to(DEVICE)
- torch.manual_seed(SEED)
- cust = CustCls(vocab_size=VOCAB).to(DEVICE)
-
- x = torch.randn(BATCH, FEATURES, device=DEVICE)
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
- torch.manual_seed(0)
- y1 = orig(x)
- torch.manual_seed(0)
- y2 = cust(x)
- assert_close(y1, y2, atol=0, rtol=0)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/tests/dropout_fix/fastmri_jax/test_model_equivalence.py b/tests/dropout_fix/fastmri_jax/test_model_equivalence.py
deleted file mode 100644
index fbb6d2499..000000000
--- a/tests/dropout_fix/fastmri_jax/test_model_equivalence.py
+++ /dev/null
@@ -1,118 +0,0 @@
-"""
-Runs fwd pass with random input for FASTMRI U-Net models and compares outputs.
-Run it as:
- python3 tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
-"""
-
-import os
-
-
-from absl.testing import absltest
-from absl.testing import parameterized
-import jax
-import jax.numpy as jnp
-# import equinox as eqx
-
-
-from algoperf.workloads.fastmri.fastmri_jax.models_ref import \
- UNet as OriginalUNet
-from algoperf.workloads.fastmri.fastmri_jax.models import \
- UNet as CustomUNet
-
-BATCH, IN_CHANS, H, W = 4, 1, 256, 256
-OUT_CHANS, C, LAYERS = 1, 32, 4
-SEED = 1996
-
-
-class ModelEquivalenceTest(parameterized.TestCase):
-
- @parameterized.named_parameters(
- dict(
- testcase_name='UNet, p=0.0',
- dropout_rate=0.0),
- dict(
- testcase_name='UNet, p=0.1',
- dropout_rate=0.1),
- )
- def test_forward(self, dropout_rate):
- OrigCls, CustCls = (OriginalUNet, CustomUNet)
-
-
- # init model
- rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
-
- kwargs = dict(num_pool_layers = LAYERS, num_channels=IN_CHANS)
- orig_model = OrigCls(**kwargs)
- cust_model = CustCls(**kwargs)
-
- fake_batch = jnp.ones((BATCH, IN_CHANS, H, W))
-
- initial_params_original = orig_model.init({'params': rng},
- fake_batch,
- train=False)
- initial_params_custom = cust_model.init({'params': rng},
- fake_batch,
- train=False)
-
- # fwd
- x = jax.random.normal(data_rng, shape=(BATCH, H, W))
-
- for mode in ('train', 'eval'):
- train = mode == 'train'
- y1 = orig_model.apply(
- initial_params_original,
- x,
- train=train,
- rngs={'dropout': dropout_rng})
- y2 = cust_model.apply(
- initial_params_custom,
- x,
- train=train,
- dropout_rate=dropout_rate,
- rngs={'dropout': dropout_rng})
-
- assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
-
- @parameterized.named_parameters(
- dict(testcase_name='UNet, default'),
- )
- def test_default_dropout(self):
- """Test default dropout_rate."""
- OrigCls, CustCls = (OriginalUNet, CustomUNet)
-
-
- # init model
- rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
-
- kwargs = dict(num_pool_layers=LAYERS,
- num_channels=IN_CHANS,
- )
- orig_model = OrigCls(**kwargs)
- cust_model = CustCls(**kwargs)
-
- fake_batch = jnp.ones((2, IN_CHANS, H, W))
-
- initial_params_original = orig_model.init({'params': rng},
- fake_batch,
- train=False)
- initial_params_custom = cust_model.init({'params': rng},
- fake_batch,
- train=False)
-
- # fwd
- x = jax.random.normal(data_rng, shape=(BATCH, H, W))
-
- for mode in ('train', 'eval'):
- train = mode == 'train'
- y1 = orig_model.apply(
- initial_params_original,
- x,
- train=train,
- rngs={'dropout': dropout_rng})
- y2 = cust_model.apply(
- initial_params_custom, x, train=train, rngs={'dropout': dropout_rng})
-
- assert jnp.allclose(y1, y2, atol=0, rtol=0)
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py b/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
deleted file mode 100644
index 1d318e8c6..000000000
--- a/tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
+++ /dev/null
@@ -1,130 +0,0 @@
-"""
-Runs fwd pass with random input for FASTMRI U-Net models and compares outputs.
-Run it as:
- python3 tests/dropout_fix/fastmri_pytorch/test_model_equivalence.py
-"""
-
-import os
-
-from absl.testing import absltest
-from absl.testing import parameterized
-import torch
-from torch.testing import assert_close
-
-from algoperf.workloads.fastmri.fastmri_pytorch.models import \
- UNet as OriginalUNet
-from algoperf.workloads.fastmri.fastmri_pytorch.models_dropout import \
- UNet as CustomUNet
-
-BATCH, IN_CHANS, H, W = 4, 1, 256, 256
-OUT_CHANS, C, LAYERS = 1, 32, 4
-DEVICE = 'cuda'
-TORCH_COMPILE = False
-SEED = 1996
-
-os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = True
-torch.use_deterministic_algorithms(True)
-
-
-class FastMRIModeEquivalenceTest(parameterized.TestCase):
-
- def fwd_pass(self, orig, cust, dropout_rate):
- x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
- torch.manual_seed(0)
- y1 = orig(x)
- torch.manual_seed(0)
- y2 = cust(x, dropout_rate)
- assert_close(y1, y2, atol=0, rtol=0)
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(0)
- y2 = cust(x)
- assert_close(y1, y2, atol=0, rtol=0)
-
- @parameterized.named_parameters(
- dict(testcase_name='p=0.0', dropout_rate=0.0),
- dict(testcase_name='p=0.1', dropout_rate=0.1),
- dict(testcase_name='p=0.7', dropout_rate=0.7),
- dict(testcase_name='p=1.0', dropout_rate=1.0),
- )
- def test_dropout_values(self, dropout_rate):
- """Test different values of dropout_rate."""
-
- torch.manual_seed(SEED)
- orig = OriginalUNet(
- IN_CHANS, OUT_CHANS, C, LAYERS, dropout_rate=dropout_rate).to(DEVICE)
-
- torch.manual_seed(SEED)
- cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
-
- cust.load_state_dict(orig.state_dict()) # sync weights
- if TORCH_COMPILE:
- orig = torch.compile(orig)
- cust = torch.compile(cust)
-
- self.fwd_pass(orig, cust, dropout_rate)
-
- @parameterized.named_parameters(
- dict(testcase_name='default', use_tanh=False, use_layer_norm=False),
- dict(testcase_name='tanh', use_tanh=True, use_layer_norm=False),
- dict(testcase_name='layer_norm', use_tanh=False, use_layer_norm=True),
- dict(testcase_name='both', use_tanh=True, use_layer_norm=True),
- )
- def test_arch_configs(self, use_tanh, use_layer_norm):
- """Test different architecture configurations, fixed dropout_rate."""
- dropout_rate = 0.1
-
- torch.manual_seed(SEED)
- orig = OriginalUNet(
- IN_CHANS,
- OUT_CHANS,
- C,
- LAYERS,
- dropout_rate=dropout_rate,
- use_tanh=use_tanh,
- use_layer_norm=use_layer_norm).to(DEVICE)
-
- torch.manual_seed(SEED)
- cust = CustomUNet(
- IN_CHANS,
- OUT_CHANS,
- C,
- LAYERS,
- use_tanh=use_tanh,
- use_layer_norm=use_layer_norm).to(DEVICE)
-
- cust.load_state_dict(orig.state_dict()) # sync weights
- if TORCH_COMPILE:
- orig = torch.compile(orig)
- cust = torch.compile(cust)
-
- self.fwd_pass(orig, cust, dropout_rate)
-
- @parameterized.named_parameters(
- dict(testcase_name=''),)
- def test_default_dropout(self):
- """Test default dropout_rate."""
-
- torch.manual_seed(SEED)
- orig = OriginalUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
- torch.manual_seed(SEED)
- cust = CustomUNet(IN_CHANS, OUT_CHANS, C, LAYERS).to(DEVICE)
- cust.load_state_dict(orig.state_dict()) # sync weights
-
- x = torch.randn(BATCH, IN_CHANS, H, W, device=DEVICE)
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
- torch.manual_seed(0)
- y1 = orig(x)
- torch.manual_seed(0)
- y2 = cust(x)
- assert_close(y1, y2, atol=0, rtol=0)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py
deleted file mode 100644
index 6806ca99f..000000000
--- a/tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py
+++ /dev/null
@@ -1,110 +0,0 @@
-"""
-Runs fwd pass with random input for FASTMRI U-Net models and compares outputs.
-Run it as:
- python3 tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py
-"""
-
-import os
-
-from absl.testing import absltest
-from absl.testing import parameterized
-import jax
-import jax.numpy as jnp
-
-from algoperf.workloads.imagenet_vit.imagenet_jax.models_ref import \
- ViT as OriginalVit
-from algoperf.workloads.imagenet_vit.imagenet_jax.models import \
- ViT as CustomVit
-
-# Model / test hyper-params
-INPUT_SHAPE = (2, 224, 124, 3)
-SEED = 1994
-
-class ImageNetVitModeEquivalenceTest(parameterized.TestCase):
-
- @parameterized.named_parameters(
- dict(
- testcase_name='ViT, p=0.0',
- dropout_rate=0.0),
- dict(
- testcase_name='ViT, p=0.1',
- dropout_rate=0.1),
- )
- def test_forward(self, dropout_rate):
- OrigCls, CustCls = (OriginalVit, CustomVit)
-
-
- # init model
- rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
-
- orig_model = OrigCls()
- cust_model = CustCls()
-
- fake_batch = jnp.ones(INPUT_SHAPE)
-
- initial_params_original = orig_model.init({'params': rng},
- fake_batch,
- train=False)
- initial_params_custom = cust_model.init({'params': rng},
- fake_batch,
- train=False)
-
- # fwd
- x = jax.random.normal(data_rng, shape=INPUT_SHAPE)
-
- for mode in ('train', 'eval'):
- train = mode == 'train'
- y1 = orig_model.apply(
- initial_params_original,
- x,
- train=train,
- rngs={'dropout': dropout_rng})
- y2 = cust_model.apply(
- initial_params_custom,
- x,
- train=train,
- dropout_rate=dropout_rate,
- rngs={'dropout': dropout_rng})
-
- assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
-
- @parameterized.named_parameters(
- dict(testcase_name='UNet, default'),
- )
- def test_default_dropout(self):
- """Test default dropout_rate."""
- OrigCls, CustCls = (OriginalVit, CustomVit)
-
-
- # init model
- rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
-
- orig_model = OrigCls()
- cust_model = CustCls()
-
- fake_batch = jnp.ones(INPUT_SHAPE)
-
- initial_params_original = orig_model.init({'params': rng},
- fake_batch,
- train=False)
- initial_params_custom = cust_model.init({'params': rng},
- fake_batch,
- train=False)
-
- # fwd
- x = jax.random.normal(data_rng, INPUT_SHAPE)
-
- for mode in ('train', 'eval'):
- train = mode == 'train'
- y1 = orig_model.apply(
- initial_params_original,
- x,
- train=train,
- rngs={'dropout': dropout_rng})
- y2 = cust_model.apply(
- initial_params_custom, x, train=train, rngs={'dropout': dropout_rng})
-
- assert jnp.allclose(y1, y2, atol=0, rtol=0)
-
-if __name__ == '__main__':
- absltest.main()
\ No newline at end of file
diff --git a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py b/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
deleted file mode 100644
index f51eaec7e..000000000
--- a/tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
+++ /dev/null
@@ -1,138 +0,0 @@
-"""
-Runs fwd pass with random input for FASTMRI U-Net models and compares outputs.
-Run it as:
- python3 tests/dropout_fix/imagenet_vit_pytorch/test_model_equivalence.py
-"""
-
-import itertools
-import os
-
-from absl.testing import absltest
-from absl.testing import parameterized
-import torch
-from torch.testing import assert_close
-
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import \
- ViT as OriginalVit
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.models_dropout import \
- ViT as CustomVit
-
-# Model / test hyper-params
-BATCH, C, H, W = 4, 3, 224, 224 # input shape (N,C,H,W)
-WIDTH, DEPTH, HEADS = 256, 4, 8
-DROPOUT_RATE = None
-DEVICE = 'cuda'
-SEED = 1996
-
-os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = True
-torch.use_deterministic_algorithms(True)
-
-
-class ImageNetVitModeEquivalenceTest(parameterized.TestCase):
-
- def fwd_pass(self, orig, cust, dropout_rate):
- x = torch.randn(BATCH, C, H, W, device=DEVICE)
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
- torch.manual_seed(0)
- y1 = orig(x)
- torch.manual_seed(0)
- y2 = cust(x, dropout_rate)
- assert_close(y1, y2, atol=0, rtol=0)
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(0)
- y2 = cust(x, dropout_rate)
- assert_close(y1, y2, atol=0, rtol=0)
-
- @parameterized.named_parameters(
- dict(testcase_name='p=0.0', dropout_rate=0.0),
- dict(testcase_name='p=0.1', dropout_rate=0.1),
- dict(testcase_name='p=0.6', dropout_rate=0.6),
- dict(testcase_name='p=1.0', dropout_rate=1.0),
- )
- def test_dropout_values(self, dropout_rate):
- """Test different dropout_rates."""
-
- torch.manual_seed(SEED)
- orig = OriginalVit(
- width=WIDTH,
- depth=DEPTH,
- num_heads=HEADS,
- dropout_rate=dropout_rate,
- ).to(DEVICE)
-
- torch.manual_seed(SEED)
- cust = CustomVit(
- width=WIDTH,
- depth=DEPTH,
- num_heads=HEADS,
- ).to(DEVICE)
-
- cust.load_state_dict(orig.state_dict()) # sync weights
- self.fwd_pass(orig, cust, dropout_rate)
-
- @parameterized.named_parameters([
- dict(
- testcase_name=f"GLU={use_glu}_LN={use_post_ln}_MAP={use_map}",
- use_glu=use_glu,
- use_post_ln=use_post_ln,
- use_map=use_map,
- ) for use_glu,
- use_post_ln,
- use_map in itertools.product([False, True], repeat=3)
- ])
- def test_arch(self, use_glu, use_post_ln, use_map):
- """Test different architecture configurations, fixed dropout_rate."""
- dropout_rate = 0.1
-
- torch.manual_seed(SEED)
- orig = OriginalVit(
- width=WIDTH,
- depth=DEPTH,
- num_heads=HEADS,
- use_glu=use_glu,
- use_post_layer_norm=use_post_ln,
- use_map=use_map,
- dropout_rate=dropout_rate,
- ).to(DEVICE)
-
- torch.manual_seed(SEED)
- cust = CustomVit(
- width=WIDTH,
- depth=DEPTH,
- num_heads=HEADS,
- use_glu=use_glu,
- use_post_layer_norm=use_post_ln,
- use_map=use_map,
- ).to(DEVICE)
-
- cust.load_state_dict(orig.state_dict()) # sync weights
- self.fwd_pass(orig, cust, dropout_rate)
-
- @parameterized.named_parameters(
- dict(testcase_name=''),)
- def test_default_dropout(self):
- """Test default dropout_rate."""
-
- torch.manual_seed(SEED)
- orig = OriginalVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE)
- torch.manual_seed(SEED)
- cust = CustomVit(width=WIDTH, depth=DEPTH, num_heads=HEADS).to(DEVICE)
- cust.load_state_dict(orig.state_dict()) # sync weights
-
- x = torch.randn(BATCH, C, H, W, device=DEVICE)
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
- torch.manual_seed(0)
- y1 = orig(x)
- torch.manual_seed(0)
- y2 = cust(x)
- assert_close(y1, y2, atol=0, rtol=0)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py
deleted file mode 100644
index c5d849dce..000000000
--- a/tests/dropout_fix/librispeech_conformer_jax/test_model_equivalence.py
+++ /dev/null
@@ -1,117 +0,0 @@
-"""
-Runs fwd pass with random input for FASTMRI U-Net models and compares outputs.
-Run it as:
- python3 tests/dropout_fix/imagenet_vit_jax/test_model_equivalence.py
-"""
-
-import os
-
-from absl.testing import absltest
-from absl.testing import parameterized
-import jax
-import jax.numpy as jnp
-
-from algoperf.workloads.librispeech_conformer.librispeech_jax.models import ConformerConfig as CustClsConfig
-from algoperf.workloads.librispeech_conformer.librispeech_jax.models import Conformer as CustCls
-
-from algoperf.workloads.librispeech_conformer.librispeech_jax.models_ref import ConformerConfig as OrigClsConfig
-from algoperf.workloads.librispeech_conformer.librispeech_jax.models_ref import Conformer as OrigCls
-
-
-# Model / test hyper-params
-INPUT_SHAPE = [(3200,), (3200,)]
-SEED = 1994
-
-class ModeEquivalenceTest(parameterized.TestCase):
-
- @parameterized.named_parameters(
- dict(
- testcase_name='Conformer, p=0.0',
- dropout_rate=0.0),
- dict(
- testcase_name='Conformer, p=0.1',
- dropout_rate=0.1),
- )
- def test_forward(self, dropout_rate):
-
- # init model
- rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
-
- orig_model = OrigCls(OrigClsConfig(attention_dropout_rate=dropout_rate,
- attention_residual_dropout_rate = dropout_rate,
- conv_residual_dropout_rate = dropout_rate,
- feed_forward_dropout_rate = dropout_rate,
- feed_forward_residual_dropout_rate = dropout_rate,
- input_dropout_rate=dropout_rate))
- cust_model = CustCls(CustClsConfig())
-
- fake_batch = [jnp.zeros((2, *x), jnp.float32) for x in INPUT_SHAPE]
-
- initial_params_original = orig_model.init({'params': rng},
- *fake_batch,
- train=False)
- initial_params_custom = cust_model.init({'params': rng},
- *fake_batch,
- train=False)
-
- # fwd
- x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE]
-
- for train in [True]:
- (y1, _), _ = orig_model.apply(
- initial_params_original,
- *x,
- train=train,
- rngs={'dropout': dropout_rng},
- mutable=['batch_stats'],)
- (y2, _), _ = cust_model.apply(
- initial_params_custom,
- *x,
- train=train,
- dropout_rate=dropout_rate,
- rngs={'dropout': dropout_rng},
- mutable=['batch_stats'])
-
- assert jnp.allclose(y1, y2)
-
-
-
- @parameterized.named_parameters(
- dict(testcase_name='Conformer, default'),
- )
- def test_default_dropout(self):
- """Test default dropout_rate."""
- # init model
- rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
-
- orig_model = OrigCls(OrigClsConfig())
- cust_model = CustCls(CustClsConfig())
-
- fake_batch = [jnp.zeros((2, *x), jnp.float32) for x in INPUT_SHAPE]
-
- initial_params_original = orig_model.init({'params': rng},
- *fake_batch,
- train=False)
- initial_params_custom = cust_model.init({'params': rng},
- *fake_batch,
- train=False)
-
- # fwd
- x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE]
-
- for mode in ('train', 'eval'):
- train = mode == 'train'
- (y1, _), _ = orig_model.apply(
- initial_params_original,
- *x,
- train=train,
- rngs={'dropout': dropout_rng}, mutable=['batch_stats'])
- (y2, _), _ = cust_model.apply(
- initial_params_custom, *x, train=train, rngs={'dropout': dropout_rng}, mutable=['batch_stats'])
-
-
- assert jnp.allclose(y1, y2)
-
-
-if __name__ == '__main__':
- absltest.main()
\ No newline at end of file
diff --git a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
deleted file mode 100644
index 02f3a3d84..000000000
--- a/tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
+++ /dev/null
@@ -1,84 +0,0 @@
-"""
-Runs fwd pass with random input for LIBRISPEECH Conformer models and compares outputs.
-Run with:
- python3 tests/dropout_fix/librispeech_conformer_pytorch/test_model_equivalence.py
-
-NOTE: we don't test for default dropout_rate values, since they changed.
-"""
-
-import os
-
-from absl.testing import absltest
-from absl.testing import parameterized
-import torch
-from torch.testing import assert_close
-
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
- ConformerConfig as OriginalConfig
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
- ConformerEncoderDecoder as OriginalModel
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import \
- ConformerConfig as CustomConfig
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models_dropout import \
- ConformerEncoderDecoder as CustomModel
-
-N_LAYERS = 3
-B, T = 32, 36_000
-DEVICE = 'cuda'
-
-os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = True
-torch.use_deterministic_algorithms(mode=True)
-SEED = 1996
-
-
-class ConformerEquivalenceTest(parameterized.TestCase):
-
- @parameterized.named_parameters(
- dict(testcase_name='p=0.0', dropout_rate=0.0),
- dict(testcase_name='p=0.2', dropout_rate=0.2),
- dict(testcase_name='p=0.7', dropout_rate=0.7),
- dict(testcase_name='p=1.0', dropout_rate=1.0),
- )
- def test_forward(self, dropout_rate):
-
- torch.manual_seed(SEED)
- orig = OriginalModel(
- OriginalConfig(
- num_encoder_layers=N_LAYERS,
- attention_residual_dropout_rate=dropout_rate,
- conv_residual_dropout_rate=dropout_rate,
- feed_forward_residual_dropout_rate=dropout_rate,
- input_dropout_rate=dropout_rate,
- )).to(DEVICE)
-
- torch.manual_seed(SEED)
- cust = CustomModel(CustomConfig(num_encoder_layers=N_LAYERS)).to(DEVICE)
-
- orig.load_state_dict(cust.state_dict()) # sync weights
-
- x = torch.randn(B, T, device=DEVICE)
- paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
-
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
-
- torch.manual_seed(SEED)
- y1, p1 = orig(x, paddings)
- torch.manual_seed(SEED)
- y2, p2 = cust(x, paddings, dropout_rate=dropout_rate)
-
- assert_close(y1, y2, atol=0, rtol=0)
- assert_close(p1, p2, atol=0, rtol=0)
-
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(SEED)
- y2, p2 = cust(x, paddings)
- assert_close(y1, y2, atol=0, rtol=0)
- assert_close(p1, p2, atol=0, rtol=0)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py
deleted file mode 100644
index 5155199f5..000000000
--- a/tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py
+++ /dev/null
@@ -1,113 +0,0 @@
-"""
-Runs fwd pass with random input for LIBRISPEECH Deepspeech models and compares outputs.
-Run it as:
-  python3 tests/dropout_fix/librispeech_deepspeech_jax/test_model_equivalence.py
-"""
-
-import os
-
-from absl.testing import absltest
-from absl.testing import parameterized
-import jax
-import jax.numpy as jnp
-
-from algoperf.workloads.librispeech_deepspeech.librispeech_jax.models import DeepspeechConfig as CustClsConfig
-from algoperf.workloads.librispeech_deepspeech.librispeech_jax.models import Deepspeech as CustCls
-
-from algoperf.workloads.librispeech_deepspeech.librispeech_jax.models_ref import DeepspeechConfig as OrigClsConfig
-from algoperf.workloads.librispeech_deepspeech.librispeech_jax.models_ref import Deepspeech as OrigCls
-
-
-# Model / test hyper-params
-INPUT_SHAPE = [(3200,), (3200,)]
-SEED = 1994
-
-class ModeEquivalenceTest(parameterized.TestCase):
-
- @parameterized.named_parameters(
- dict(
-          testcase_name='DeepSpeech, p=0.0',
- dropout_rate=0.0),
- dict(
-          testcase_name='DeepSpeech, p=0.1',
- dropout_rate=0.1),
- )
- def test_forward(self, dropout_rate):
-
- # init model
- rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
-
-    orig_model = OrigCls(OrigClsConfig())
-    cust_model = CustCls(CustClsConfig())
-
- fake_batch = [jnp.zeros((2, *x), jnp.float32) for x in INPUT_SHAPE]
-
- initial_params_original = orig_model.init({'params': rng},
- *fake_batch,
- train=False)
- initial_params_custom = cust_model.init({'params': rng},
- *fake_batch,
- train=False)
-
- # fwd
- x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE]
-
- for mode in ('train', 'eval'):
- train = mode == 'train'
- (y1, _), _ = orig_model.apply(
- initial_params_original,
- *x,
- train=train,
- rngs={'dropout': dropout_rng},
- mutable=['batch_stats'],)
- (y2, _), _ = cust_model.apply(
- initial_params_custom,
- *x,
- train=train,
- dropout_rate=dropout_rate,
- rngs={'dropout': dropout_rng},
- mutable=['batch_stats'])
-
- assert jnp.allclose(y1, y2)
-
-
-
- @parameterized.named_parameters(
-      dict(testcase_name='DeepSpeech, default'),
- )
- def test_default_dropout(self):
- """Test default dropout_rate."""
- # init model
- rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
-
-    orig_model = OrigCls(OrigClsConfig())
-    cust_model = CustCls(CustClsConfig())
-
- fake_batch = [jnp.zeros((2, *x), jnp.float32) for x in INPUT_SHAPE]
-
- initial_params_original = orig_model.init({'params': rng},
- *fake_batch,
- train=False)
- initial_params_custom = cust_model.init({'params': rng},
- *fake_batch,
- train=False)
-
- # fwd
- x = [jax.random.normal(data_rng, (2, *x)) for x in INPUT_SHAPE]
-
- for mode in ('train', 'eval'):
- train = mode == 'train'
- (y1, _), _ = orig_model.apply(
- initial_params_original,
- *x,
- train=train,
- rngs={'dropout': dropout_rng}, mutable=['batch_stats'])
- (y2, _), _ = cust_model.apply(
- initial_params_custom, *x, train=train, rngs={'dropout': dropout_rng}, mutable=['batch_stats'])
-
-
- assert jnp.allclose(y1, y2)
-
-
-if __name__ == '__main__':
- absltest.main()
\ No newline at end of file
diff --git a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py b/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
deleted file mode 100644
index 7d6a94592..000000000
--- a/tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
+++ /dev/null
@@ -1,124 +0,0 @@
-"""
-Runs fwd pass with random input for LIBRISPEECH Deepspeech models and compares outputs.
-Run with:
- python3 tests/dropout_fix/librispeech_deepspeech_pytorch/test_model_equivalence.py
-
-`dropout_rate` controls the following args:
-- `input_dropout_rate` (if None, 0.1)
-- `feed_forward_dropout_rate` (if None, 0.1)
-"""
-
-import os
-
-from absl.testing import absltest
-from absl.testing import parameterized
-import torch
-from torch.testing import assert_close
-
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \
- DeepspeechConfig as OriginalConfig
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \
- DeepspeechEncoderDecoder as OriginalModel
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import \
- DeepspeechConfig as CustomConfig
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models_dropout import \
- DeepspeechEncoderDecoder as CustomModel
-
-B, T = 32, 30_000
-DEVICE = 'cuda'
-TORCH_COMPILE = False
-
-os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = True
-torch.use_deterministic_algorithms(mode=True)
-SEED = 1996
-
-
-class DeepSpeechEquivalenceTest(parameterized.TestCase):
-
- @parameterized.named_parameters(
- dict(testcase_name='p=0.0', dropout_rate=0.0),
- dict(testcase_name='p=0.2', dropout_rate=0.2),
- dict(testcase_name='p=0.7', dropout_rate=0.7),
- dict(testcase_name='p=1.0', dropout_rate=1.0),
- )
- def test_forward(self, dropout_rate):
- """Test different dropout_rate values."""
-
- torch.manual_seed(SEED)
- orig = OriginalModel(
- OriginalConfig(
- num_lstm_layers=2,
- num_ffn_layers=2,
- input_dropout_rate=dropout_rate,
- feed_forward_dropout_rate=dropout_rate,
- )).to(DEVICE)
-
- torch.manual_seed(SEED)
- cust = CustomModel(CustomConfig(
- num_lstm_layers=2,
- num_ffn_layers=2,
- )).to(DEVICE)
-
- orig.load_state_dict(cust.state_dict()) # sync weights
-
- if TORCH_COMPILE:
- orig = torch.compile(orig)
- cust = torch.compile(cust)
-
- x = torch.randn(B, T, device=DEVICE)
- paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
-
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
-
- torch.manual_seed(SEED)
- y1, p1 = orig(x, paddings)
-
- torch.manual_seed(SEED)
- y2, p2 = cust(x, paddings, dropout_rate=dropout_rate)
-
- assert_close(y1, y2, atol=0, rtol=0)
- assert_close(p1, p2, atol=0, rtol=0)
-
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(SEED)
- y2, p2 = cust(x, paddings)
- assert_close(y1, y2, atol=0, rtol=0)
- assert_close(p1, p2, atol=0, rtol=0)
-
- @parameterized.named_parameters(
- dict(testcase_name=''),)
- def test_default_dropout(self):
- """Test default dropout_rate."""
-
- torch.manual_seed(SEED)
- orig = OriginalModel(OriginalConfig(num_lstm_layers=2,
- num_ffn_layers=2)).to(DEVICE)
- torch.manual_seed(SEED)
- cust = CustomModel(CustomConfig(num_lstm_layers=2,
- num_ffn_layers=2)).to(DEVICE)
- orig.load_state_dict(cust.state_dict()) # sync weights
-
- if TORCH_COMPILE:
- orig = torch.compile(orig)
- cust = torch.compile(cust)
-
- x = torch.randn(B, T, device=DEVICE)
- paddings = torch.zeros(B, T, dtype=torch.float32, device=DEVICE)
-
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
- torch.manual_seed(SEED)
- y1, p1 = orig(x, paddings)
- torch.manual_seed(SEED)
- y2, p2 = cust(x, paddings)
- assert_close(y1, y2, atol=0, rtol=0)
- assert_close(p1, p2, atol=0, rtol=0)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/tests/dropout_fix/ogbg_jax/test_model_equivalence.py b/tests/dropout_fix/ogbg_jax/test_model_equivalence.py
deleted file mode 100644
index 1b5b8a180..000000000
--- a/tests/dropout_fix/ogbg_jax/test_model_equivalence.py
+++ /dev/null
@@ -1,133 +0,0 @@
-"""
-Runs fwd pass with random input for OGBG GNN models and compares outputs.
-Run with:
-  python3 tests/dropout_fix/ogbg_jax/test_model_equivalence.py
-"""
-
-import os
-
-import jraph
-
-from absl.testing import absltest
-from absl.testing import parameterized
-import jax
-import jax.numpy as jnp
-
-from algoperf.workloads.ogbg.ogbg_jax.models_ref import \
- GNN as OrigCls
-from algoperf.workloads.ogbg.ogbg_jax.models import \
- GNN as CustCls
-
-# Model / test hyper-params
-SEED = 1994
-
-class ModeEquivalenceTest(parameterized.TestCase):
-
- @parameterized.named_parameters(
- dict(
- testcase_name='OGBG, p=0.0',
- dropout_rate=0.0),
- dict(
- testcase_name='OGBG, p=0.1',
- dropout_rate=0.1),
- )
- def test_forward(self, dropout_rate):
- # init model
- rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
-
- orig_model = OrigCls(num_outputs=128, dropout_rate=dropout_rate)
- cust_model = CustCls(num_outputs=128)
-
- fake_batch = jraph.GraphsTuple(
- n_node=jnp.asarray([1]),
- n_edge=jnp.asarray([1]),
- nodes=jnp.ones((1, 9)),
- edges=jnp.ones((1, 3)),
- globals=jnp.zeros((1, 128)),
- senders=jnp.asarray([0]),
- receivers=jnp.asarray([0]))
-
- initial_params_original = orig_model.init({'params': rng},
- fake_batch,
- train=False)
- initial_params_custom = cust_model.init({'params': rng},
- fake_batch,
- train=False)
-
- # fwd
- x = jraph.GraphsTuple(
- n_node=jnp.asarray([1]),
- n_edge=jnp.asarray([1]),
- nodes=jnp.ones((1, 9)),
- edges=jnp.ones((1, 3)),
- globals=jnp.zeros((1, 128)),
- senders=jnp.asarray([0]),
- receivers=jnp.asarray([0]))
-
- for mode in ('train', 'eval'):
- train = mode == 'train'
- y1 = orig_model.apply(
- initial_params_original,
- x,
- train=train,
- rngs={'dropout': dropout_rng})
- y2 = cust_model.apply(
- initial_params_custom,
- x,
- train=train,
- dropout_rate=dropout_rate,
- rngs={'dropout': dropout_rng})
-
- assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
-
- @parameterized.named_parameters(
- dict(testcase_name='OGBG, default'),
- )
- def test_default_dropout(self):
- """Test default dropout_rate."""
-
-
- # init model
- rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
-
- orig_model = OrigCls(num_outputs=128)
- cust_model = CustCls(num_outputs=128)
-
- fake_batch = jraph.GraphsTuple(
- n_node=jnp.asarray([1]),
- n_edge=jnp.asarray([1]),
- nodes=jnp.ones((1, 9)),
- edges=jnp.ones((1, 3)),
- globals=jnp.zeros((1, 128)),
- senders=jnp.asarray([0]),
- receivers=jnp.asarray([0]))
-
- initial_params_original = orig_model.init({'params': rng},
- fake_batch,
- train=False)
- initial_params_custom = cust_model.init({'params': rng},
- fake_batch,
- train=False)
-
- # fwd
- x = jraph.GraphsTuple(
- n_node=jnp.asarray([1]),
- n_edge=jnp.asarray([1]),
- nodes=jnp.ones((1, 9)),
- edges=jnp.ones((1, 3)),
- globals=jnp.zeros((1, 128)),
- senders=jnp.asarray([0]),
- receivers=jnp.asarray([0]))
-
- for mode in ('train', 'eval'):
- train = mode == 'train'
- y1 = orig_model.apply(
- initial_params_original,
- x,
- train=train,
- rngs={'dropout': dropout_rng})
- y2 = cust_model.apply(
- initial_params_custom, x, train=train, rngs={'dropout': dropout_rng})
-
- assert jnp.allclose(y1, y2, atol=0, rtol=0)
-
-if __name__ == '__main__':
- absltest.main()
\ No newline at end of file
diff --git a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py b/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
deleted file mode 100644
index 3b3feb680..000000000
--- a/tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
+++ /dev/null
@@ -1,118 +0,0 @@
-"""
-Runs fwd pass with random graphs for OGBG GNN models and compares outputs.
-Run with:
- python3 tests/dropout_fix/ogbg_pytorch/test_model_equivalence.py
-"""
-
-import os
-import random
-
-from absl.testing import absltest
-from absl.testing import parameterized
-from jraph import GraphsTuple
-import numpy as np
-import torch
-from torch.testing import assert_close
-
-from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as OriginalModel
-from algoperf.workloads.ogbg.ogbg_pytorch.models_dropout import \
- GNN as CustomModel
-
-B, N, E = 8, 20, 40 # graphs, nodes/graph, edges/graph
-NODE_FDIM, EDGE_FDIM = 9, 3 # expected feature dims
-DEVICE = 'cuda'
-
-os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = True
-torch.use_deterministic_algorithms(True)
-SEED = 1996
-
-
-def _rand_graph():
- total_nodes, total_edges = B * N, B * E
- nodes = torch.randn(total_nodes, NODE_FDIM, device=DEVICE)
- edges = torch.randn(total_edges, EDGE_FDIM, device=DEVICE)
- senders, receivers = [], []
- for i in range(B):
- offset = i * N
- s = torch.randint(N, (E,), device=DEVICE) + offset
- r = torch.randint(N, (E,), device=DEVICE) + offset
- senders.append(s), receivers.append(r)
- senders = torch.cat(senders)
- receivers = torch.cat(receivers)
- n_node = torch.full((B,), N, device=DEVICE, dtype=torch.int32)
- n_edge = torch.full((B,), E, device=DEVICE, dtype=torch.int32)
- return GraphsTuple(nodes, edges, receivers, senders, None, n_node, n_edge)
-
-
-class GNNEquivalenceTest(parameterized.TestCase):
-
- @parameterized.named_parameters(
- dict(testcase_name='0.0', dropout_rate=0.0),
- dict(testcase_name='0.2', dropout_rate=0.2),
- dict(testcase_name='0.7', dropout_rate=0.7),
- dict(testcase_name='1.0', dropout_rate=1.0),
- )
- def test_forward(self, dropout_rate):
- """Test different dropout_rates."""
-
- orig = OriginalModel(dropout_rate=dropout_rate).to(DEVICE)
- cust = CustomModel().to(DEVICE)
- orig.load_state_dict(cust.state_dict()) # sync weights
-
- graph = _rand_graph()
-
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
-
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y1 = orig(graph)
-
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y2 = cust(graph, dropout_rate=dropout_rate)
-
- assert_close(y1, y2, atol=0, rtol=0)
-
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y2 = cust(graph)
- assert_close(y1, y2, atol=0, rtol=0)
-
- @parameterized.named_parameters(
- dict(testcase_name=''),)
- def test_default_dropout(self):
- """Test default dropout_rate."""
-
- orig = OriginalModel().to(DEVICE)
- cust = CustomModel().to(DEVICE)
- orig.load_state_dict(cust.state_dict()) # sync weights
-
- graph = _rand_graph()
-
- for mode in ('train', 'eval'):
- getattr(orig, mode)()
- getattr(cust, mode)()
-
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y1 = orig(graph)
-
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y2 = cust(graph)
-
- assert_close(y1, y2, atol=0, rtol=0)
-
-
-if __name__ == '__main__':
- absltest.main()
diff --git a/tests/dropout_fix/wmt_jax/test_model_equivalence.py b/tests/dropout_fix/wmt_jax/test_model_equivalence.py
deleted file mode 100644
index d420bf691..000000000
--- a/tests/dropout_fix/wmt_jax/test_model_equivalence.py
+++ /dev/null
@@ -1,113 +0,0 @@
-"""
-Runs fwd pass with random input for WMT Transformer models and compares outputs.
-Run with:
-  python3 tests/dropout_fix/wmt_jax/test_model_equivalence.py
-"""
-
-import os
-
-from absl.testing import absltest
-from absl.testing import parameterized
-import jax
-import jax.numpy as jnp
-
-from algoperf.workloads.wmt.wmt_jax.models import TransformerConfig as CustClsConfig
-from algoperf.workloads.wmt.wmt_jax.models import Transformer as CustCls
-
-from algoperf.workloads.wmt.wmt_jax.models_ref import TransformerConfig as OrigClsConfig
-from algoperf.workloads.wmt.wmt_jax.models_ref import Transformer as OrigCls
-
-
-# Model / test hyper-params
-SEED = 1994
-
-class ModeEquivalenceTest(parameterized.TestCase):
-
- @parameterized.named_parameters(
- dict(
- testcase_name='WMT, p=0.0',
- dropout_rate=0.0,
- train=True),
- dict(
-          testcase_name='WMT, p=0.1',
- dropout_rate=0.1,
- train=False),
- )
- def test_forward(self, dropout_rate, train):
-
- # init model
- rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
-
- orig_model = OrigCls(OrigClsConfig(deterministic=not train, attention_dropout_rate=dropout_rate, dropout_rate=dropout_rate))
- cust_model = CustCls(CustClsConfig(deterministic=not train))
-
- init_fake_batch_size = 8
- input_shape = (init_fake_batch_size, 256)
- target_shape = (init_fake_batch_size, 256)
-
- initial_params_original = orig_model.init({'params': rng},
- jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32))
- initial_params_custom = cust_model.init({'params': rng},
- jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32),)
-
- # fwd
-
- y1 = orig_model.apply(
- initial_params_original,
- jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32),
- rngs={'dropout': dropout_rng})
-
- y2 = cust_model.apply(
- initial_params_custom,
- jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32),
- dropout_rate=dropout_rate,
- rngs={'dropout': dropout_rng})
-
- assert jnp.allclose(y1, y2)
-
-
-
- @parameterized.named_parameters(
- dict(testcase_name='WMT, default train', train=True),
- dict(testcase_name='WMT, default eval', train=False),
- )
- def test_default_dropout(self, train):
- """Test default dropout_rate."""
- rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
- orig_model = OrigCls(OrigClsConfig(deterministic=not train))
- cust_model = CustCls(CustClsConfig(deterministic=not train))
-
- init_fake_batch_size = 8
- input_shape = (init_fake_batch_size, 256)
- target_shape = (init_fake_batch_size, 256)
-
- initial_params_original = orig_model.init({'params': rng},
- jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32))
- initial_params_custom = cust_model.init({'params': rng},
- jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32))
-
- # fwd
-
- y1 = orig_model.apply(
- initial_params_original,
- jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32),
- rngs={'dropout': dropout_rng})
-
- y2 = cust_model.apply(
- initial_params_custom,
- jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32),
- rngs={'dropout': dropout_rng})
-
- assert jnp.allclose(y1, y2)
-
-
-if __name__ == '__main__':
- absltest.main()
\ No newline at end of file
diff --git a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py b/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
deleted file mode 100644
index 03f289a68..000000000
--- a/tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
+++ /dev/null
@@ -1,121 +0,0 @@
-"""
-Runs fwd pass with random input for WMT Transformer models and compares outputs.
-Run with:
- python3 tests/dropout_fix/wmt_pytorch/test_model_equivalence.py
-"""
-
-import os
-import random
-
-from absl.testing import absltest
-from absl.testing import parameterized
-import numpy as np
-import torch
-from torch.testing import assert_close
-
-from algoperf.workloads.wmt.wmt_pytorch.models import \
- Transformer as OriginalModel
-from algoperf.workloads.wmt.wmt_pytorch.models_dropout import \
- Transformer as CustomModel
-
-B, SRC_LEN, TGT_LEN, NTOK = 16, 80, 80, 32_000
-DEVICE = "cuda"
-SEED = 1996
-
-os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = True
-torch.use_deterministic_algorithms(True)
-
-
-def _rand_tokens(bs, seqlen):
- return torch.randint(1, NTOK, (bs, seqlen), device=DEVICE)
-
-
-class TransformerEquivalenceTest(parameterized.TestCase):
-
- @parameterized.named_parameters(
- # NOTE: removed dropout=1.0 since it will generate nan in scaled_dot_product_attention
- dict(testcase_name="0.0", dropout_rate=0.0, compile=False),
- dict(testcase_name="0.2", dropout_rate=0.2, compile=False),
- dict(testcase_name="0.7", dropout_rate=0.7, compile=False),
- dict(testcase_name="p=0.0_compile", dropout_rate=0.0, compile=True),
- dict(testcase_name="p=0.2_compile", dropout_rate=0.2, compile=True),
- dict(testcase_name="p=0.7_compile", dropout_rate=0.7, compile=True),
- )
- def test_dropout_value(self, dropout_rate, compile):
-
- orig = OriginalModel(
- dropout_rate=dropout_rate,
- attention_dropout_rate=dropout_rate).to(DEVICE)
- cust = CustomModel().to(DEVICE)
-
- orig.load_state_dict(cust.state_dict()) # sync weights
-
- if compile:
- orig = torch.compile(orig)
- cust = torch.compile(cust)
-
- src = _rand_tokens(B, SRC_LEN)
- tgt = _rand_tokens(B, TGT_LEN)
-
- for mode in ("train", "eval"):
- getattr(orig, mode)()
- getattr(cust, mode)()
-
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y1 = orig(src, tgt)
-
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y2 = cust(src, tgt, dropout_rate=dropout_rate)
-
- assert_close(y1, y2, atol=0, rtol=0)
-
- if mode == 'eval': # one extra test: omit dropout at eval
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y2 = cust(src, tgt)
- assert_close(y1, y2, atol=0, rtol=0)
-
- @parameterized.named_parameters(
- dict(testcase_name="default", compile=False),
- dict(testcase_name="default_compile", compile=True),
- )
- def test_default(self, compile):
-
- orig = OriginalModel().to(DEVICE)
- cust = CustomModel().to(DEVICE)
-
- orig.load_state_dict(cust.state_dict()) # sync weights
-
- if compile:
- orig = torch.compile(orig)
- cust = torch.compile(cust)
-
- src = _rand_tokens(B, SRC_LEN)
- tgt = _rand_tokens(B, TGT_LEN)
-
- for mode in ("train", "eval"):
- getattr(orig, mode)()
- getattr(cust, mode)()
-
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y1 = orig(src, tgt)
-
- torch.manual_seed(SEED)
- random.seed(SEED)
- np.random.seed(SEED)
- y2 = cust(src, tgt)
-
- assert_close(y1, y2, atol=0, rtol=0)
-
-
-if __name__ == "__main__":
- absltest.main()
From 66f5ed320226e81c39394053386cf09a864bf881 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 25 Jun 2025 05:08:11 +0000
Subject: [PATCH 086/123] fix formatting
---
tests/test_jax_utils.py | 33 +++++++++++++++++----------------
1 file changed, 17 insertions(+), 16 deletions(-)
diff --git a/tests/test_jax_utils.py b/tests/test_jax_utils.py
index 04713915f..8a156149b 100644
--- a/tests/test_jax_utils.py
+++ b/tests/test_jax_utils.py
@@ -3,18 +3,19 @@
Run it as: pytest
"""
+from functools import partial
import os
from absl.testing import absltest
from absl.testing import parameterized
+import flax.linen as nn
import jax
import jax.numpy as jnp
-import flax.linen as nn
+from jax.tree_util import tree_leaves
+from jax.tree_util import tree_map
+from jax.tree_util import tree_structure
-from jax.tree_util import tree_structure, tree_leaves, tree_map
from algoperf.jax_utils import Dropout
-from functools import partial
-
SEED = 1996
DEFAULT_DROPOUT = 0.5
@@ -213,25 +214,25 @@ def test_jitted_updates(self, dropout_rate, mode):
jitted_custom_apply = jax.jit(
partial(cust_model.apply), static_argnames=['train'])
-
def multiple_fwd_passes_custom_layer():
- for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
- y2 = jitted_custom_apply(
- initial_variables_custom,
- x,
- train=train,
- dropout_rate=d,
- rngs={"dropout": dropout_rng},
- )
- return y2
+ for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
+ y2 = jitted_custom_apply(
+ initial_variables_custom,
+ x,
+ train=train,
+ dropout_rate=d,
+ rngs={"dropout": dropout_rng},
+ )
+ return y2
def multiple_fwd_passes_original_layer():
- for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
- y1 = jitted_original_apply(
+ for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
+ y1 = jitted_original_apply(
initial_variables_original,
x,
train=train,
rngs={"dropout": dropout_rng})
+
if __name__ == "__main__":
absltest.main()
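The jitted sweep over `dropout_rate` values in the hunk above relies on the rate being a call-time argument: `jax.jit` traces it like any other input, so changing its value between calls reuses the same compiled executable instead of triggering a recompile. A minimal standalone sketch of that idea with a plain jitted function (names and shapes are illustrative, not part of this patch):

```python
import jax
import jax.numpy as jnp


@jax.jit
def scale_and_mask(x, rate, key):
  # rate is a traced value: sweeping it below reuses one compiled executable.
  keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
  return jnp.where(keep, x / (1.0 - rate), 0.0)


x = jnp.ones((4, 4))
key = jax.random.key(0)
for rate in (0.1, 0.2, 0.3):
  y = scale_and_mask(x, rate, key)
```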
From 62b1cc97b346110574c086f923f987685f3642c1 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 25 Jun 2025 05:24:36 +0000
Subject: [PATCH 087/123] remove reference model implementations used for
testing
---
.../criteo1tb/criteo1tb_jax/models_ref.py | 219 ------
.../fastmri/fastmri_jax/models_ref.py | 220 ------
.../imagenet_jax/models_ref.py | 132 ----
.../imagenet_vit/imagenet_jax/models_ref.py | 235 ------
.../librispeech_jax/models_ref.py | 712 ------------------
.../librispeech_jax/models_ref.py | 525 -------------
.../workloads/ogbg/ogbg_jax/models_ref.py | 84 ---
algoperf/workloads/wmt/wmt_jax/models_ref.py | 604 ---------------
8 files changed, 2731 deletions(-)
delete mode 100644 algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py
delete mode 100644 algoperf/workloads/fastmri/fastmri_jax/models_ref.py
delete mode 100644 algoperf/workloads/imagenet_resnet/imagenet_jax/models_ref.py
delete mode 100644 algoperf/workloads/imagenet_vit/imagenet_jax/models_ref.py
delete mode 100644 algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py
delete mode 100644 algoperf/workloads/librispeech_deepspeech/librispeech_jax/models_ref.py
delete mode 100644 algoperf/workloads/ogbg/ogbg_jax/models_ref.py
delete mode 100644 algoperf/workloads/wmt/wmt_jax/models_ref.py
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py
deleted file mode 100644
index d0352e290..000000000
--- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models_ref.py
+++ /dev/null
@@ -1,219 +0,0 @@
-"""A JAX implementation of DLRM-Small."""
-
-from typing import Sequence
-
-import flax.linen as nn
-from jax import nn as jnn
-import jax.numpy as jnp
-
-
-class DLRMResNet(nn.Module):
- """Define a DLRMResNet model.
-
- Parameters:
- vocab_size: the size of a single unified embedding table.
- mlp_bottom_dims: dimensions of dense layers of the bottom mlp.
- mlp_top_dims: dimensions of dense layers of the top mlp.
- num_dense_features: number of dense features as the bottom mlp input.
- embed_dim: embedding dimension.
- """
-
- vocab_size: int = 32 * 128 * 1024 # 4_194_304
- num_dense_features: int = 13
- mlp_bottom_dims: Sequence[int] = (256, 256, 256)
- mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1)
- embed_dim: int = 128
- dropout_rate: float = 0.1
- use_layer_norm: bool = False # Unused.
-  embedding_init_multiplier: float = None  # Unused.
-
- @nn.compact
- def __call__(self, x, train):
- bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
- cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
-
- # bottom mlp
- mlp_bottom_dims = self.mlp_bottom_dims
-
- bot_mlp_input = nn.Dense(
- mlp_bottom_dims[0],
- kernel_init=jnn.initializers.glorot_uniform(),
- bias_init=jnn.initializers.normal(stddev=1.0 / mlp_bottom_dims[0]**0.5),
- )(
- bot_mlp_input)
- bot_mlp_input = nn.relu(bot_mlp_input)
-
- for dense_dim in mlp_bottom_dims[1:]:
- x = nn.Dense(
- dense_dim,
- kernel_init=jnn.initializers.glorot_uniform(),
- bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5),
- )(
- bot_mlp_input)
- bot_mlp_input += nn.relu(x)
-
- base_init_fn = jnn.initializers.uniform(scale=1.0)
- # Embedding table init and lookup for a single unified table.
- idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size
-
- def scaled_init(key, shape, dtype=jnp.float_):
- return base_init_fn(key, shape, dtype) / jnp.sqrt(self.vocab_size)
-
- embedding_table = self.param('embedding_table',
- scaled_init, [self.vocab_size, self.embed_dim])
-
- embed_features = embedding_table[idx_lookup]
- batch_size = bot_mlp_input.shape[0]
- embed_features = jnp.reshape(embed_features,
- (batch_size, 26 * self.embed_dim))
- top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1)
- mlp_input_dim = top_mlp_input.shape[1]
- mlp_top_dims = self.mlp_top_dims
- num_layers_top = len(mlp_top_dims)
- top_mlp_input = nn.Dense(
- mlp_top_dims[0],
- kernel_init=jnn.initializers.normal(
- stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0]))),
- bias_init=jnn.initializers.normal(
- stddev=jnp.sqrt(1.0 / mlp_top_dims[0])))(
- top_mlp_input)
- top_mlp_input = nn.relu(top_mlp_input)
- for layer_idx, fan_out in list(enumerate(mlp_top_dims))[1:-1]:
- fan_in = mlp_top_dims[layer_idx - 1]
- x = nn.Dense(
- fan_out,
- kernel_init=jnn.initializers.normal(
- stddev=jnp.sqrt(2.0 / (fan_in + fan_out))),
- bias_init=jnn.initializers.normal(
- stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))(
- top_mlp_input)
- x = nn.relu(x)
- if self.dropout_rate and layer_idx == num_layers_top - 2:
- x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
- top_mlp_input += x
- # In the DLRM model the last layer width is always 1. We can hardcode that
- # below.
- logits = nn.Dense(
- 1,
- kernel_init=jnn.initializers.normal(
- stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))),
- bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0)))(
- top_mlp_input)
- return logits
-
-
-def dot_interact(concat_features):
- """Performs feature interaction operation between dense or sparse features.
- Input tensors represent dense or sparse features.
- Pre-condition: The tensors have been stacked along dimension 1.
- Args:
- concat_features: Array of features with shape [B, n_features, feature_dim].
- Returns:
- activations: Array representing interacted features.
- """
- batch_size = concat_features.shape[0]
-
- # Interact features, select upper or lower-triangular portion, and reshape.
- xactions = jnp.matmul(concat_features,
- jnp.transpose(concat_features, [0, 2, 1]))
- feature_dim = xactions.shape[-1]
-
- indices = jnp.array(jnp.triu_indices(feature_dim))
- num_elems = indices.shape[1]
- indices = jnp.tile(indices, [1, batch_size])
- indices0 = jnp.reshape(
- jnp.tile(jnp.reshape(jnp.arange(batch_size), [-1, 1]), [1, num_elems]),
- [1, -1])
- indices = tuple(jnp.concatenate((indices0, indices), 0))
- activations = xactions[indices]
- activations = jnp.reshape(activations, [batch_size, -1])
- return activations
-
-
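The triangular gather in `dot_interact` above keeps each unordered feature pair, including self-interactions, exactly once, so with `n_features = F` the output has `F * (F + 1) / 2` columns per example. A minimal sketch of the same computation with simpler indexing (illustrative shapes, not the deleted helper itself):

```python
import jax.numpy as jnp

B, F, D = 4, 3, 8  # batch, number of features, feature dim (illustrative)
feats = jnp.arange(B * F * D, dtype=jnp.float32).reshape(B, F, D)

# Pairwise dot products between stacked features, as in dot_interact above.
xactions = jnp.matmul(feats, jnp.transpose(feats, [0, 2, 1]))  # (B, F, F)

# Keep the upper triangle (diagonal included): F * (F + 1) // 2 values per row.
iu, ju = jnp.triu_indices(F)
interactions = xactions[:, iu, ju]
print(interactions.shape)  # (4, 6) for F = 3
```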
-class DlrmSmall(nn.Module):
- """Define a DLRM-Small model.
-
- Parameters:
- vocab_size: vocab size of embedding table.
- num_dense_features: number of dense features as the bottom mlp input.
- mlp_bottom_dims: dimensions of dense layers of the bottom mlp.
- mlp_top_dims: dimensions of dense layers of the top mlp.
- embed_dim: embedding dimension.
- """
-
- vocab_size: int = 32 * 128 * 1024 # 4_194_304.
- num_dense_features: int = 13
- mlp_bottom_dims: Sequence[int] = (512, 256, 128)
- mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1)
- embed_dim: int = 128
- dropout_rate: float = 0.0
- use_layer_norm: bool = False
- embedding_init_multiplier: float = None
-
- @nn.compact
- def __call__(self, x, train):
- bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
- cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
-
- # Bottom MLP.
- for dense_dim in self.mlp_bottom_dims:
- bot_mlp_input = nn.Dense(
- dense_dim,
- kernel_init=jnn.initializers.glorot_uniform(),
- bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)),
- )(
- bot_mlp_input)
- bot_mlp_input = nn.relu(bot_mlp_input)
- if self.use_layer_norm:
- bot_mlp_input = nn.LayerNorm()(bot_mlp_input)
- bot_mlp_output = bot_mlp_input
- batch_size = bot_mlp_output.shape[0]
- feature_stack = jnp.reshape(bot_mlp_output,
- [batch_size, -1, self.embed_dim])
-
- # Embedding table look-up.
- idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size
-
- if self.embedding_init_multiplier is None:
- scale = 1 / jnp.sqrt(self.vocab_size)
- else:
- scale = self.embedding_init_multiplier
-
- def scaled_init(key, shape, dtype=jnp.float_):
- return jnn.initializers.uniform(scale=1.0)(key, shape, dtype) * scale
-
- embedding_table = self.param('embedding_table',
- scaled_init, [self.vocab_size, self.embed_dim])
-
- idx_lookup = jnp.reshape(idx_lookup, [-1])
- embed_features = embedding_table[idx_lookup]
- embed_features = jnp.reshape(embed_features,
- [batch_size, -1, self.embed_dim])
- if self.use_layer_norm:
- embed_features = nn.LayerNorm()(embed_features)
- feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1)
- dot_interact_output = dot_interact(concat_features=feature_stack)
- top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output],
- axis=-1)
- mlp_input_dim = top_mlp_input.shape[1]
- mlp_top_dims = self.mlp_top_dims
- num_layers_top = len(mlp_top_dims)
- for layer_idx, fan_out in enumerate(mlp_top_dims):
- fan_in = mlp_input_dim if layer_idx == 0 else mlp_top_dims[layer_idx - 1]
- top_mlp_input = nn.Dense(
- fan_out,
- kernel_init=jnn.initializers.normal(
- stddev=jnp.sqrt(2.0 / (fan_in + fan_out))),
- bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)))(
- top_mlp_input)
- if layer_idx < (num_layers_top - 1):
- top_mlp_input = nn.relu(top_mlp_input)
- if self.use_layer_norm:
- top_mlp_input = nn.LayerNorm()(top_mlp_input)
- if (self.dropout_rate is not None and self.dropout_rate > 0.0 and
- layer_idx == num_layers_top - 2):
- top_mlp_input = nn.Dropout(
- rate=self.dropout_rate, deterministic=not train)(
- top_mlp_input)
- logits = top_mlp_input
- return logits
\ No newline at end of file
diff --git a/algoperf/workloads/fastmri/fastmri_jax/models_ref.py b/algoperf/workloads/fastmri/fastmri_jax/models_ref.py
deleted file mode 100644
index a2d56a4b4..000000000
--- a/algoperf/workloads/fastmri/fastmri_jax/models_ref.py
+++ /dev/null
@@ -1,220 +0,0 @@
-"""Jax / Flax implementation of FastMRI U-Net.
-
-Forked from
-https://github.com/google/init2winit/blob/master/init2winit/model_lib/unet.py
-
-Original implementation:
-github.com/facebookresearch/fastMRI/blob/main/fastmri/models/unet.py
-
-Training:
-github.com/facebookresearch/fastMRI/blob/main/fastmri/pl_modules/unet_module.py
-
-Data:
-github.com/facebookresearch/fastMRI/tree/main/fastmri/data
-"""
-import functools
-from typing import Optional
-
-import flax.linen as nn
-import jax
-import jax.numpy as jnp
-
-
-def _instance_norm2d(x, axes, epsilon=1e-5):
- # promote x to at least float32, this avoids half precision computation
- # but preserves double or complex floating points
- x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x)))
- mean = jnp.mean(x, axes)
- mean2 = jnp.mean(jnp.square(x), axes)
- # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due
- # to floating point round-off errors.
- var = jnp.maximum(0., mean2 - jnp.square(mean))
- stats_shape = list(x.shape)
- for axis in axes:
- stats_shape[axis] = 1
- mean = mean.reshape(stats_shape)
- var = var.reshape(stats_shape)
- y = x - mean
- mul = jnp.sqrt(var + epsilon)
- y /= mul
- return y
-
-
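As a quick sanity check on the helper above (a sketch that assumes `_instance_norm2d` is in scope), every (sample, channel) slice should come out with approximately zero mean and unit variance, because the statistics are computed over the spatial axes only:

```python
import jax.numpy as jnp

x = jnp.arange(2 * 4 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 4, 3)
y = _instance_norm2d(x, (1, 2))  # normalize over H, W per sample and channel
print(jnp.mean(y, axis=(1, 2)))  # ~0 for every (sample, channel)
print(jnp.var(y, axis=(1, 2)))   # ~1 (slightly below 1 because of epsilon)
```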
-class UNet(nn.Module):
- """Jax / Flax implementation of a U-Net model.
-
- O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks
- for biomedical image segmentation. In International Conference on Medical
- image computing and computer-assisted intervention, pages 234–241.
- Springer, 2015.
-
-  out_channels: Number of channels in the output of the U-Net model.
-  num_channels: Number of output channels of the first convolution layer.
- num_pool_layers: Number of down-sampling and up-sampling layers.
- dropout_rate: Dropout probability.
- """
- num_channels: int = 32
- num_pool_layers: int = 4
- out_channels = 1
- dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0.
- use_tanh: bool = False
- use_layer_norm: bool = False
-
- @nn.compact
- def __call__(self, x, train=True):
- dropout_rate = self.dropout_rate
- if dropout_rate is None:
- dropout_rate = 0.0
-
- # pylint: disable=invalid-name
- _ConvBlock = functools.partial(
- ConvBlock,
- dropout_rate=dropout_rate,
- use_tanh=self.use_tanh,
- use_layer_norm=self.use_layer_norm)
- _TransposeConvBlock = functools.partial(
- TransposeConvBlock,
- use_tanh=self.use_tanh,
- use_layer_norm=self.use_layer_norm)
-
- down_sample_layers = [_ConvBlock(self.num_channels)]
-
- ch = self.num_channels
- for _ in range(self.num_pool_layers - 1):
- down_sample_layers.append(_ConvBlock(ch * 2))
- ch *= 2
- conv = _ConvBlock(ch * 2)
-
- up_conv = []
- up_transpose_conv = []
- for _ in range(self.num_pool_layers - 1):
- up_transpose_conv.append(_TransposeConvBlock(ch))
- up_conv.append(_ConvBlock(ch))
- ch //= 2
-
- up_transpose_conv.append(_TransposeConvBlock(ch))
- up_conv.append(_ConvBlock(ch))
-
- stack = []
- output = jnp.expand_dims(x, axis=-1)
-
- # apply down-sampling layers
- for layer in down_sample_layers:
- output = layer(output, train)
- stack.append(output)
- output = nn.avg_pool(output, window_shape=(2, 2), strides=(2, 2))
-
- output = conv(output, train)
-
- # apply up-sampling layers
- for transpose_conv, conv in zip(up_transpose_conv, up_conv):
- downsample_layer = stack.pop()
- output = transpose_conv(output)
-
-      # reflect pad on the right/bottom if needed to handle odd input dimensions
- padding_right = 0
- padding_bottom = 0
- if output.shape[-2] != downsample_layer.shape[-2]:
- padding_right = 1 # padding right
- if output.shape[-3] != downsample_layer.shape[-3]:
- padding_bottom = 1 # padding bottom
-
- if padding_right or padding_bottom:
- padding = ((0, 0), (0, padding_bottom), (0, padding_right), (0, 0))
- output = jnp.pad(output, padding, mode='reflect')
-
- output = jnp.concatenate((output, downsample_layer), axis=-1)
- output = conv(output, train)
-
- output = nn.Conv(
- self.out_channels, kernel_size=(1, 1), strides=(1, 1))(
- output)
- return output.squeeze(-1)
-
-
-class ConvBlock(nn.Module):
- """A Convolutional Block.
- out_channels: Number of channels in the output.
- dropout_rate: Dropout probability.
- """
- out_channels: int
- dropout_rate: float
- use_tanh: bool
- use_layer_norm: bool
-
- @nn.compact
- def __call__(self, x, train=True):
- """Forward function.
- Note: Pytorch is NCHW and jax/flax is NHWC.
- Args:
- x: Input 4D tensor of shape `(N, H, W, in_channels)`.
- train: deterministic or not (use init2winit naming).
- Returns:
- jnp.array: Output tensor of shape `(N, H, W, out_channels)`.
- """
- x = nn.Conv(
- features=self.out_channels,
- kernel_size=(3, 3),
- strides=(1, 1),
- use_bias=False)(
- x)
- if self.use_layer_norm:
- x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x)
- else:
-      # InstanceNorm2d was run with no learnable params in the reference code,
-      # so this is a simple normalization along spatial dims.
- x = _instance_norm2d(x, (1, 2))
- if self.use_tanh:
- activation_fn = nn.tanh
- else:
- activation_fn = functools.partial(jax.nn.leaky_relu, negative_slope=0.2)
- x = activation_fn(x)
-    # The reference code uses dropout2d, which applies the same mask to an
-    # entire channel; replicated here with broadcast_dims so the mask is
-    # shared over H and W.
- x = nn.Dropout(
- self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
- x)
- x = nn.Conv(
- features=self.out_channels,
- kernel_size=(3, 3),
- strides=(1, 1),
- use_bias=False)(
- x)
- if self.use_layer_norm:
- x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x)
- else:
- x = _instance_norm2d(x, (1, 2))
- x = activation_fn(x)
- x = nn.Dropout(
- self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
- x)
- return x
-
-
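The `broadcast_dims=(1, 2)` usage in `ConvBlock` above mirrors PyTorch's `Dropout2d` by sharing one keep/drop decision across the spatial axes. A small illustrative sketch of that behavior with a bare Flax `Dropout` (the shapes and the 0.5 rate are arbitrary):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((2, 4, 4, 3))  # (N, H, W, C)
drop = nn.Dropout(rate=0.5, broadcast_dims=(1, 2), deterministic=False)
y = drop.apply({}, x, rngs={'dropout': jax.random.key(0)})

# Each (sample, channel) spatial slice is either all zeros or all 1 / (1 - rate) = 2.0.
print(jnp.unique(y[0, :, :, 0]), jnp.unique(y[0, :, :, 1]))
```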
-class TransposeConvBlock(nn.Module):
- """A Transpose Convolutional Block.
- out_channels: Number of channels in the output.
- """
- out_channels: int
- use_tanh: bool
- use_layer_norm: bool
-
- @nn.compact
- def __call__(self, x):
- """Forward function.
- Args:
- x: Input 4D tensor of shape `(N, H, W, in_channels)`.
- Returns:
- jnp.array: Output tensor of shape `(N, H*2, W*2, out_channels)`.
- """
- x = nn.ConvTranspose(
- self.out_channels, kernel_size=(2, 2), strides=(2, 2), use_bias=False)(
- x)
- x = _instance_norm2d(x, (1, 2))
- if self.use_tanh:
- activation_fn = nn.tanh
- else:
- activation_fn = functools.partial(jax.nn.leaky_relu, negative_slope=0.2)
- x = activation_fn(x)
- return x
\ No newline at end of file
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/models_ref.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/models_ref.py
deleted file mode 100644
index 357dadc13..000000000
--- a/algoperf/workloads/imagenet_resnet/imagenet_jax/models_ref.py
+++ /dev/null
@@ -1,132 +0,0 @@
-"""Jax implementation of ResNet V1.
-
-Adapted from Flax example:
-https://github.com/google/flax/blob/main/examples/imagenet/models.py.
-"""
-
-import functools
-from typing import Any, Callable, Optional, Tuple
-
-from flax import linen as nn
-import jax.numpy as jnp
-
-from algoperf import spec
-
-ModuleDef = nn.Module
-
-
-class ResNetBlock(nn.Module):
- """ResNet block."""
- filters: int
- conv: ModuleDef
- norm: ModuleDef
- act: Callable
- strides: Tuple[int, int] = (1, 1)
- bn_init_scale: float = 0.
-
- @nn.compact
- def __call__(self, x: spec.Tensor) -> spec.Tensor:
- residual = x
- y = self.conv(self.filters, (3, 3), self.strides)(x)
- y = self.norm()(y)
- y = self.act(y)
- y = self.conv(self.filters, (3, 3))(y)
- y = self.norm(scale_init=nn.initializers.constant(self.bn_init_scale))(y)
-
- if residual.shape != y.shape or self.strides != (1, 1):
- residual = self.conv(
- self.filters, (1, 1), self.strides, name='Conv_proj')(
- residual)
- residual = self.norm(name='BatchNorm_proj')(residual)
-
- return self.act(residual + y)
-
-
-class BottleneckResNetBlock(nn.Module):
- """Bottleneck ResNet block."""
- filters: int
- conv: ModuleDef
- norm: ModuleDef
- act: Callable
- strides: Tuple[int, int] = (1, 1)
- bn_init_scale: Optional[float] = None
-
- @nn.compact
- def __call__(self, x: spec.Tensor) -> spec.Tensor:
- residual = x
- y = self.conv(self.filters, (1, 1))(x)
- y = self.norm()(y)
- y = self.act(y)
- y = self.conv(self.filters, (3, 3), self.strides)(y)
- y = self.norm()(y)
- y = self.act(y)
- y = self.conv(self.filters * 4, (1, 1))(y)
- y = self.norm(scale_init=nn.initializers.constant(self.bn_init_scale))(y)
-
- if residual.shape != y.shape or self.strides != (1, 1):
- residual = self.conv(
- self.filters * 4, (1, 1), self.strides, name='Conv_proj')(
- residual)
- residual = self.norm(name='BatchNorm_proj')(residual)
-
- return self.act(residual + y)
-
-
-class ResNet(nn.Module):
- stage_sizes: Tuple[int]
- block_cls: ModuleDef
- num_classes: int
- num_filters: int = 64
- dtype: Any = jnp.float32
- act: Callable = nn.relu
- bn_init_scale: float = 0.
-
- @nn.compact
- def __call__(self,
- x: spec.Tensor,
- update_batch_norm: bool = True,
- use_running_average_bn: Optional[bool] = None) -> spec.Tensor:
- conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
-
- # Preserve default behavior for backwards compatibility
- if use_running_average_bn is None:
- use_running_average_bn = not update_batch_norm
- norm = functools.partial(
- nn.BatchNorm,
- use_running_average=use_running_average_bn,
- momentum=0.9,
- epsilon=1e-5,
- dtype=self.dtype)
-
- x = conv(
- self.num_filters, (7, 7), (2, 2),
- padding=[(3, 3), (3, 3)],
- name='Conv_init')(
- x)
- x = norm(name='BatchNorm_init')(x)
- x = self.act(x)
- x = nn.max_pool(x, (3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))
- for i, block_size in enumerate(self.stage_sizes):
- for j in range(block_size):
- strides = (2, 2) if i > 0 and j == 0 else (1, 1)
- x = self.block_cls(
- self.num_filters * 2**i,
- strides=strides,
- conv=conv,
- norm=norm,
- act=self.act,
- bn_init_scale=self.bn_init_scale)(
- x)
- x = jnp.mean(x, axis=(1, 2))
- x = nn.Dense(
- self.num_classes,
- kernel_init=nn.initializers.normal(),
- dtype=self.dtype)(
- x)
- return x
-
-
-ResNet18 = functools.partial(
- ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock)
-ResNet50 = functools.partial(
- ResNet, stage_sizes=(3, 4, 6, 3), block_cls=BottleneckResNetBlock)
\ No newline at end of file
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models_ref.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models_ref.py
deleted file mode 100644
index beb8a2eb8..000000000
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/models_ref.py
+++ /dev/null
@@ -1,235 +0,0 @@
-"""Jax implementation of refactored and simplified ViT.
-
-Forked from:
-https://github.com/google/init2winit/blob/master/init2winit/model_lib/vit.py,
-originally from https://github.com/google/big_vision with modifications noted.
-"""
-
-from typing import Optional, Sequence, Union
-
-from flax import linen as nn
-import jax.numpy as jnp
-
-from algoperf import spec
-
-
-def posemb_sincos_2d(h: int,
- w: int,
- width: int,
- temperature: int = 10_000.,
- dtype: jnp.dtype = jnp.float32) -> spec.Tensor:
- """Follows the MoCo v3 logic."""
- y, x = jnp.mgrid[:h, :w] #pylint: disable=unpacking-non-sequence
-
- if width % 4 != 0:
- raise ValueError('Width must be mult of 4 for sincos posemb.')
- omega = jnp.arange(width // 4) / (width // 4 - 1)
- omega = 1. / (temperature**omega)
- y = jnp.einsum('m,d->md', y.flatten(), omega)
- x = jnp.einsum('m,d->md', x.flatten(), omega)
- pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1)
- return jnp.asarray(pe, dtype)[None, :, :]
-
-
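For orientation (a sketch assuming `posemb_sincos_2d` above is importable), the table has a leading broadcast axis and one row per patch position; `width` must be a multiple of 4 because it is split across sin and cos of the x and y coordinates:

```python
pe = posemb_sincos_2d(h=14, w=14, width=768)  # e.g. a 224x224 input with 16x16 patches
print(pe.shape)  # (1, 196, 768)
```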
-class MlpBlock(nn.Module):
- """Transformer MLP / feed-forward block."""
- mlp_dim: Optional[int] = None # Defaults to 4x input dim.
- use_glu: bool = False
- dropout_rate: float = 0.0
-
- @nn.compact
- def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
- """Applies Transformer MlpBlock module."""
- inits = {
- 'kernel_init': nn.initializers.xavier_uniform(),
- 'bias_init': nn.initializers.normal(stddev=1e-6),
- }
-
- d = x.shape[2]
- x = nn.Dense(self.mlp_dim or 4 * d, **inits)(x)
- x = nn.gelu(x)
-
- if self.use_glu:
- y = nn.Dense(self.mlp_dim, **inits)(x)
- x = x * y
-
- x = nn.Dropout(rate=self.dropout_rate)(x, train)
- x = nn.Dense(d, **inits)(x)
- return x
-
-
-class Encoder1DBlock(nn.Module):
- """Single transformer encoder block (MHSA + MLP)."""
- mlp_dim: Optional[int] = None # Defaults to 4x input dim.
- num_heads: int = 12
- use_glu: bool = False
- use_post_layer_norm: bool = False
- dropout_rate: float = 0.0
-
- @nn.compact
- def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
- if not self.use_post_layer_norm:
- y = nn.LayerNorm(name='LayerNorm_0')(x)
- y = nn.MultiHeadDotProductAttention(
- num_heads=self.num_heads,
- kernel_init=nn.initializers.xavier_uniform(),
- deterministic=train,
- name='MultiHeadDotProductAttention_1')(
- y)
- y = nn.Dropout(rate=self.dropout_rate)(y, train)
- x = x + y
-
- y = nn.LayerNorm(name='LayerNorm_2')(x)
- y = MlpBlock(
- mlp_dim=self.mlp_dim,
- use_glu=self.use_glu,
- dropout_rate=self.dropout_rate,
- name='MlpBlock_3')(y, train)
- y = nn.Dropout(rate=self.dropout_rate)(y, train)
- x = x + y
- else:
- y = x
- y = nn.MultiHeadDotProductAttention(
- num_heads=self.num_heads,
- kernel_init=nn.initializers.xavier_uniform(),
- deterministic=train,
- name='MultiHeadDotProductAttention_1')(
- y)
- y = nn.Dropout(rate=self.dropout_rate)(y, train)
- x = x + y
- x = nn.LayerNorm(name='LayerNorm_0')(x)
-
- y = x
- y = MlpBlock(
- mlp_dim=self.mlp_dim,
- use_glu=self.use_glu,
- dropout_rate=self.dropout_rate,
- name='MlpBlock_3')(y, train)
- y = nn.Dropout(rate=self.dropout_rate)(y, train)
- x = x + y
- x = nn.LayerNorm(name='LayerNorm_2')(x)
-
- return x
-
-
-class Encoder(nn.Module):
- """Transformer Model Encoder for sequence to sequence translation."""
- depth: int
- mlp_dim: Optional[int] = None # Defaults to 4x input dim.
- num_heads: int = 12
- dropout_rate: float = 0.0
- use_glu: bool = False
- use_post_layer_norm: bool = False
-
- @nn.compact
- def __call__(self, x: spec.Tensor, train: bool = True) -> spec.Tensor:
- # Input Encoder
- for lyr in range(self.depth):
- block = Encoder1DBlock(
- name=f'encoderblock_{lyr}',
- mlp_dim=self.mlp_dim,
- num_heads=self.num_heads,
- use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm,
- dropout_rate=self.dropout_rate)
- x = block(x, train)
- if not self.use_post_layer_norm:
- return nn.LayerNorm(name='encoder_layernorm')(x)
- else:
- return x
-
-
-class MAPHead(nn.Module):
- """Multihead Attention Pooling."""
- mlp_dim: Optional[int] = None # Defaults to 4x input dim
- num_heads: int = 12
-
- @nn.compact
- def __call__(self, x):
- n, _, d = x.shape
- probe = self.param('probe',
- nn.initializers.xavier_uniform(), (1, 1, d),
- x.dtype)
- probe = jnp.tile(probe, [n, 1, 1])
-
- x = nn.MultiHeadDotProductAttention(
- num_heads=self.num_heads,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(probe, x)
-
- y = nn.LayerNorm()(x)
- x = x + MlpBlock(mlp_dim=self.mlp_dim)(y)
- return x[:, 0]
-
-
-class ViT(nn.Module):
- """ViT model."""
-
- num_classes: int = 1000
- patch_size: Sequence[int] = (16, 16)
- width: int = 768
- depth: int = 12
- mlp_dim: Optional[int] = None # Defaults to 4x input dim.
- num_heads: int = 12
- rep_size: Union[int, bool] = True
- dropout_rate: Optional[float] = 0.0 # If None, defaults to 0.0.
- reinit: Optional[Sequence[str]] = None
- head_zeroinit: bool = True
- use_glu: bool = False
- use_post_layer_norm: bool = False
- use_map: bool = False
-
- def get_posemb(self,
- seqshape: tuple,
- width: int,
- dtype: jnp.dtype = jnp.float32) -> spec.Tensor:
- return posemb_sincos_2d(*seqshape, width, dtype=dtype)
-
- @nn.compact
- def __call__(self, x: spec.Tensor, *, train: bool = False) -> spec.Tensor:
- # Patch extraction
- x = nn.Conv(
- self.width,
- self.patch_size,
- strides=self.patch_size,
- padding='VALID',
- name='conv_patch_extract')(
- x)
-
- n, h, w, c = x.shape
- x = jnp.reshape(x, [n, h * w, c])
-
- # Add posemb before adding extra token.
- x = x + self.get_posemb((h, w), c, x.dtype)
-
- dropout_rate = self.dropout_rate
- if dropout_rate is None:
- dropout_rate = 0.0
- x = nn.Dropout(rate=dropout_rate)(x, not train)
-
- x = Encoder(
- depth=self.depth,
- mlp_dim=self.mlp_dim,
- num_heads=self.num_heads,
- use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm,
- dropout_rate=dropout_rate,
- name='Transformer')(
- x, train=not train)
-
- if self.use_map:
- x = MAPHead(num_heads=self.num_heads, mlp_dim=self.mlp_dim)(x)
- else:
- x = jnp.mean(x, axis=1)
-
- if self.rep_size:
- rep_size = self.width if self.rep_size is True else self.rep_size
- hid = nn.Dense(rep_size, name='pre_logits')
- x = nn.tanh(hid(x))
-
- if self.num_classes:
- kw = {'kernel_init': nn.initializers.zeros} if self.head_zeroinit else {}
- head = nn.Dense(self.num_classes, name='head', **kw)
- x = head(x)
-
- return x
\ No newline at end of file
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py
deleted file mode 100644
index f168833d3..000000000
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models_ref.py
+++ /dev/null
@@ -1,712 +0,0 @@
-r"""Conformer.
-
-This model uses a conformer network to convert speech to text.
-paper : https://arxiv.org/abs/2005.08100
-
-high-level overview of Conformer encoder layer.
-
- x = x + 0.5 * FeedForward(x)
- x = x + MHSA(x)
- x = x + ConvolutionBlock(x)
- x = x + 0.5 * FeedForward(x)
- y = layer_norm(x)
-"""
-
-import functools
-import math
-from typing import Any, List, Optional
-
-from flax import linen as nn
-from flax import struct
-import jax
-import jax.numpy as jnp
-import numpy as np
-
-from algoperf.workloads.librispeech_conformer.librispeech_jax import \
- librispeech_preprocessor as preprocessor
-from algoperf.workloads.librispeech_conformer.librispeech_jax import \
- spectrum_augmenter
-
-
-@struct.dataclass
-class ConformerConfig:
- """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
- vocab_size: int = 1024
- dtype: Any = jnp.float32
- encoder_dim: int = 512
- num_attention_heads: int = 8
- num_encoder_layers: int = 4
- attention_dropout_rate: float = 0.1
- # If None, defaults to 0.1.
- attention_residual_dropout_rate: Optional[float] = 0.1
- # If None, defaults to 0.0.
- conv_residual_dropout_rate: Optional[float] = 0.1
- feed_forward_dropout_rate: float = 0.1
- # If None, defaults to 0.1.
- feed_forward_residual_dropout_rate: Optional[float] = 0.1
- convolution_kernel_size: int = 5
- feed_forward_expansion_factor: int = 4
- freq_mask_count: int = 2
- freq_mask_max_bins: int = 27
- time_mask_count: int = 10
- time_mask_max_frames: int = 40
- time_mask_max_ratio: float = 0.05
- time_masks_per_frame: float = 0.0
- use_dynamic_time_mask_max_frames: bool = True
- # If None, defaults to 0.1.
- input_dropout_rate: Optional[float] = 0.1
- batch_norm_momentum: float = 0.999
- batch_norm_epsilon: float = 0.001
- use_specaug: bool = True
- attention_temperature: float = 1.0
- activation_function_name: str = 'swish'
- use_post_layer_norm: bool = True
-
-
-class LayerNorm(nn.Module):
- """Module implementing layer normalization.
-
-  This implementation is the same as in this paper:
- https://arxiv.org/pdf/1607.06450.pdf.
-
-  Note: we multiply normalized inputs by (1 + scale) and initialize scale to
-  zeros; this differs from the default Flax implementation, which multiplies
-  by scale and initializes it to ones.
- """
- dim: int = 0
- epsilon: float = 1e-6
-
- def setup(self):
- self.scale = self.param('scale', nn.initializers.zeros, [self.dim])
- self.bias = self.param('bias', nn.initializers.zeros, [self.dim])
-
- @nn.compact
- def __call__(self, inputs):
- mean = jnp.mean(inputs, axis=[-1], keepdims=True)
- var = jnp.mean(jnp.square(inputs - mean), axis=[-1], keepdims=True)
-
- normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon)
- normed_inputs *= (1 + self.scale)
- normed_inputs += self.bias
-
- return normed_inputs
-
-
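Because `scale` and `bias` are zero-initialized and the output is multiplied by `(1 + scale)`, a freshly initialized layer reduces to plain normalization. A standalone sketch of what that initial forward pass computes (illustrative values):

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 4.0]])
mean = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
normed = (x - mean) * jax.lax.rsqrt(var + 1e-6)
# With scale = bias = 0 the LayerNorm above returns exactly `normed`:
# (1 + 0) * normed + 0.
print(normed)
```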
-class Subsample(nn.Module):
- """Module to perform strided convolution in order to subsample inputs.
-
- Attributes:
- encoder_dim: model dimension of conformer.
- input_dropout_rate: dropout rate for inputs.
- """
- encoder_dim: int = 0
- input_dropout_rate: float = 0.0
-
- @nn.compact
- def __call__(self, inputs, input_paddings, train):
- output_paddings = input_paddings
- outputs = jnp.expand_dims(inputs, axis=-1)
-
- outputs, output_paddings = Conv2dSubsampling(
- input_channels=1, output_channels=self.encoder_dim)(
- outputs, output_paddings)
-
- outputs, output_paddings = Conv2dSubsampling(
- input_channels=self.encoder_dim,
- output_channels=self.encoder_dim)(outputs, output_paddings)
-
- batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape
-
- outputs = jnp.reshape(
- outputs, (batch_size, subsampled_lengths, subsampled_dims * channels))
-
- outputs = nn.Dense(
- self.encoder_dim,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- outputs)
-
- outputs = outputs + AddPositionalEmbedding(embedding_dim=self.encoder_dim)(
- seq_length=outputs.shape[1])
-
- outputs = nn.Dropout(
- rate=self.input_dropout_rate, deterministic=not train)(
- outputs)
-
- return outputs, output_paddings
-
-
-class Conv2dSubsampling(nn.Module):
- """Helper module used in Subsample layer.
-
- 1) Performs strided convolution over inputs and then applies non-linearity.
- 2) Also performs strided convolution over input_paddings to return the correct
- paddings for downstream layers.
- """
- input_channels: int = 0
- output_channels: int = 0
- filter_stride: List[int] = (2, 2)
- padding: str = 'SAME'
-
- def setup(self):
- self.filter_shape = (3, 3, self.input_channels, self.output_channels)
- self.kernel = self.param('kernel',
- nn.initializers.xavier_uniform(),
- self.filter_shape)
- self.bias = self.param(
- 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels)
-
- @nn.compact
- def __call__(self, inputs, paddings):
- # Computing strided convolution to subsample inputs.
- feature_group_count = inputs.shape[3] // self.filter_shape[2]
- outputs = jax.lax.conv_general_dilated(
- lhs=inputs,
- rhs=self.kernel,
- window_strides=self.filter_stride,
- padding=self.padding,
- rhs_dilation=(1, 1),
- dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
- feature_group_count=feature_group_count)
-
- outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,))
- outputs = nn.relu(outputs)
-
- # Computing correct paddings post input convolution.
- input_length = paddings.shape[1]
- stride = self.filter_stride[0]
-
- pad_len = (input_length + stride - 1) // stride * stride - input_length
- out_padding = jax.lax.conv_general_dilated(
- lhs=paddings[:, :, None],
- rhs=jnp.ones([1, 1, 1]),
- window_strides=self.filter_stride[:1],
- padding=[(0, pad_len)],
- dimension_numbers=('NHC', 'HIO', 'NHC'))
- out_padding = jnp.squeeze(out_padding, axis=-1)
-
- # Mask outputs by correct paddings to ensure padded elements in inputs map
- # to padded value in outputs.
- outputs = outputs * \
- (1.0 - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1))
- return outputs, out_padding
-
-
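The `pad_len` arithmetic above rounds the unpadded length up to the next multiple of the stride before convolving the padding mask, so the mask is subsampled to exactly the same length as the features. A tiny worked example of that arithmetic (values are illustrative):

```python
# With stride 2 and an odd input length, one extra frame of padding is added.
input_length, stride = 7, 2
pad_len = (input_length + stride - 1) // stride * stride - input_length
subsampled_length = (input_length + pad_len) // stride
print(pad_len, subsampled_length)  # 1 4
```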
-class FeedForwardModule(nn.Module):
- """Feedforward block of conformer layer.
- """
- config: ConformerConfig
-
- @nn.compact
- def __call__(self, inputs, padding_mask=None, train=False):
- config = self.config
-
- inputs = LayerNorm(dim=config.encoder_dim)(inputs)
-
- inputs = nn.Dense(
- config.encoder_dim * config.feed_forward_expansion_factor,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- inputs)
- if config.activation_function_name == 'swish':
- activation_fn = nn.swish
- elif config.activation_function_name == 'gelu':
- activation_fn = nn.gelu
- else:
- raise ValueError('Only "swish" and "gelu" are supported '
-                       'config.activation_function_name values, received '
- f'{config.activation_function_name}')
- inputs = activation_fn(inputs)
- inputs = nn.Dropout(rate=config.feed_forward_dropout_rate)(
- inputs, deterministic=not train)
-
- inputs = inputs * padding_mask
-
- inputs = nn.Dense(
- config.encoder_dim,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- inputs)
- inputs = inputs * padding_mask
-
- if config.feed_forward_residual_dropout_rate is None:
- feed_forward_residual_dropout_rate = 0.1
- else:
- feed_forward_residual_dropout_rate = (
- config.feed_forward_residual_dropout_rate)
- inputs = nn.Dropout(rate=feed_forward_residual_dropout_rate)(
- inputs, deterministic=not train)
-
- return inputs
-
-
-class AddPositionalEmbedding(nn.Module):
- """Adds (optionally learned) positional embeddings to the inputs.
-
- Attributes:
- max_len: maximum possible length for the input
- posemb_init: positional embedding initializer
- """
- min_timescale: int = 1
- max_timescale: int = 10_000
- embedding_dim: int = 512
-
- @nn.compact
- def __call__(self, seq_length):
- position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :]
- num_timescales = self.embedding_dim // 2
- log_timescale_increment = (
- math.log(float(self.max_timescale) / float(self.min_timescale)) /
- jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1))
- inv_timescales = self.min_timescale * jnp.exp(
- jnp.arange(num_timescales, dtype=jnp.float32) *
- -log_timescale_increment)
- scaled_time = (
- position[:, :, jnp.newaxis] *
- inv_timescales[jnp.newaxis, jnp.newaxis, :])
- signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)],
- axis=2).astype(jnp.float32)
- # Force usage of `np` rather than `jnp` to compute static values at trace
- # time.
- signal = jnp.pad(signal,
- [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]])
- return signal
-
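[Editorial note, not part of the patch] For orientation, a small NumPy sketch of the timescale grid built above, assuming the defaults min_timescale=1, max_timescale=10_000, embedding_dim=512; the returned signal has shape (1, seq_length, embedding_dim), with sin features in the first half of the channels and cos features in the second.

import numpy as np

# Sketch of the inverse-timescale grid used by AddPositionalEmbedding.
num_timescales = 512 // 2
log_increment = np.log(10_000.0 / 1.0) / (num_timescales - 1)
inv_timescales = np.exp(np.arange(num_timescales) * -log_increment)
print(inv_timescales[0], inv_timescales[-1])  # 1.0 ... ~1e-4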
-
-# Adapted from lingvo attention layer for query scaling
-# https://github.com/tensorflow/lingvo/blob/7de4ca8fff3cb28c2ecb21bbd7b02a964ce727f7/lingvo/jax/layers/attentions.py#L201
-class QueryScaler(nn.Module):
- """A layer to scale individual dims of the query attention matrix."""
- dim: int = 0
-
- def setup(self):
- self.scale = self.param('scale', nn.initializers.zeros, [self.dim])
-
- @nn.compact
- def __call__(self, inputs):
- inputs_shape = inputs.shape
- if inputs_shape[-1] != self.dim:
- raise ValueError('QueryScaler expects inputs to have'
- ' same last dimension as scaling param.')
-
- # 1.0/jax.nn.softplus(0.0) = 1.442695041. Hard code this number so that we
- # can avoid unnecessary XLA op fusion mess on TPU.
- r_softplus_0 = 1.442695041
-
- scale = jnp.array(r_softplus_0, dtype=inputs.dtype)
- scale *= jax.nn.softplus(self.scale)
-
- return inputs * scale
-
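[Editorial note, not part of the patch] A quick numeric check of the hard-coded constant: softplus(0) = ln(2), so its reciprocal is about 1.442695041, which makes the effective scale exactly 1 when the 'scale' parameter is at its zero initialization.

import math

# softplus(0) = ln(1 + e^0) = ln(2); the layer hard-codes its reciprocal so
# that r_softplus_0 * softplus(0) == 1.0 at initialization.
r_softplus_0 = 1.0 / math.log(2.0)
print(round(r_softplus_0, 9))  # 1.442695041
assert abs(r_softplus_0 * math.log(2.0) - 1.0) < 1e-12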
-
-# Modified from the default flax linen dot-product attention function to add
-# query scaling; the original function is here:
-# https://github.com/google/flax/blob/a9af38085a7a49b571cf37d375060fd683e74972/flax/linen/attention.py#L121
-def dot_product_attention(query,
- key,
- value,
- bias=None,
- mask=None,
- broadcast_dropout=True,
- dropout_rng=None,
- dropout_rate=0.,
- deterministic=False,
- dtype=jnp.float32,
- precision=None,
- temperature=1.0):
- """Computes dot-product attention given query, key, and value.
-
- This is the core function for applying attention based on
- https://arxiv.org/abs/1706.03762. It's slightly modified to add query scaling.
- It calculates the attention weights given query and key and combines the
- values using the attention weights.
-
- Note: query, key, value needn't have any batch dimensions.
-
- Args:
- query: queries for calculating attention with shape of
- `[batch..., q_length, num_heads, qk_depth_per_head]`.
- key: keys for calculating attention with shape of
- `[batch..., kv_length, num_heads, qk_depth_per_head]`.
- value: values to be used in attention with shape of
- `[batch..., kv_length, num_heads, v_depth_per_head]`.
- bias: bias for the attention weights. This should be broadcastable to the
- shape `[batch..., num_heads, q_length, kv_length]`.
- This can be used for incorporating causal masks, padding masks,
- proximity bias, etc.
- mask: mask for the attention weights. This should be broadcastable to the
- shape `[batch..., num_heads, q_length, kv_length]`.
- This can be used for incorporating causal masks.
- Attention weights are masked out if their corresponding mask value
- is `False`.
- broadcast_dropout: bool: use a broadcasted dropout along batch dims.
- dropout_rng: JAX PRNGKey: to be used for dropout
- dropout_rate: dropout rate
- deterministic: bool, deterministic or not (to apply dropout)
- dtype: the dtype of the computation (default: float32)
- precision: numerical precision of the computation, see `jax.lax.Precision`
- for details.
- temperature: scalar factor applied to the attention output.
-
- Returns:
- Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
- """
- assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
- assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], (
- 'q, k, v batch dims must match.')
- assert query.shape[-2] == key.shape[-2] == value.shape[-2], (
- 'q, k, v num_heads must match.')
- assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
-
- # compute attention weights
- query = QueryScaler(dim=query.shape[-1])(query)
- attn_weights = nn.attention.dot_product_attention_weights(
- query,
- key,
- bias,
- mask,
- broadcast_dropout,
- dropout_rng,
- dropout_rate,
- deterministic,
- dtype,
- precision)
-
- # return weighted sum over values for each query position
- return jnp.einsum(
- '...hqk,...khd->...qhd', attn_weights, value,
- precision=precision) * temperature
-
-
-class MultiHeadedSelfAttention(nn.Module):
- """Self attention sub-layer used in the Conformer layer.
-
- Input is first normalized using layer norm. Output is processed using
- multi-headed attention.
-
- Note: this attention implementation uses a learned scale parameter to scale
- the query matrix before passing it to the flax attention module.
- """
- config: ConformerConfig = None
-
- @nn.compact
- def __call__(self, inputs, paddings, train):
- config = self.config
- mask_paddings = 1 - paddings
- attention_mask = nn.make_attention_mask(
- mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32)
-
- inputs = LayerNorm(dim=config.encoder_dim)(inputs)
- attention_fn = functools.partial(
- dot_product_attention, temperature=config.attention_temperature)
- result = nn.MultiHeadDotProductAttention(
- num_heads=config.num_attention_heads,
- qkv_features=config.encoder_dim,
- decode=False,
- dtype=config.dtype,
- kernel_init=nn.initializers.xavier_uniform(),
- bias_init=nn.initializers.zeros,
- use_bias=True,
- broadcast_dropout=False,
- attention_fn=attention_fn,
- dropout_rate=config.attention_dropout_rate,
- deterministic=not train)(
- inputs_q=inputs, mask=attention_mask)
-
- if config.attention_residual_dropout_rate is None:
- attention_residual_dropout_rate = 0.1
- else:
- attention_residual_dropout_rate = config.attention_residual_dropout_rate
- result = nn.Dropout(
- rate=attention_residual_dropout_rate, deterministic=not train)(
- result)
-
- return result
-
-
-class BatchNorm(nn.Module):
- """Implements batch norm respecting input paddings.
-
- This implementation takes into account input padding by masking inputs before
- computing mean and variance.
-
- This is inspired by the lingvo JAX implementation of BatchNorm:
- https://github.com/tensorflow/lingvo/blob/84b85514d7ad3652bc9720cb45acfab08604519b/lingvo/jax/layers/normalizations.py#L92
-
- and the corresponding defaults for momentum and epsilon have been copied over
- from lingvo.
- """
- config: ConformerConfig
-
- def setup(self):
- dim = self.config.encoder_dim
- dtype = self.config.dtype
-
- self.ra_mean = self.variable('batch_stats',
- 'mean',
- lambda s: jnp.zeros(s, dtype),
- dim)
- self.ra_var = self.variable('batch_stats',
- 'var',
- lambda s: jnp.ones(s, dtype),
- dim)
-
- self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype)
- self.beta = self.param('bias', nn.initializers.zeros, dim, dtype)
-
- @nn.compact
- def __call__(self,
- inputs,
- input_paddings,
- update_batch_norm,
- use_running_average_bn):
- rank = inputs.ndim
- reduce_over_dims = list(range(0, rank - 1))
-
- padding = jnp.expand_dims(input_paddings, -1)
- momentum = self.config.batch_norm_momentum
- epsilon = self.config.batch_norm_epsilon
-
- if use_running_average_bn:
- mean = self.ra_mean.value
- var = self.ra_var.value
-
- else:
- # compute batch statistics
- mask = 1.0 - padding
- sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True)
- count_v = jnp.sum(
- jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True)
-
- count_v = jnp.maximum(count_v, 1.0)
- mean = sum_v / count_v
-
- sum_vv = jnp.sum(
- (inputs - mean) * (inputs - mean) * mask,
- axis=reduce_over_dims,
- keepdims=True)
-
- var = sum_vv / count_v
-
- if update_batch_norm:
- self.ra_mean.value = momentum * \
- self.ra_mean.value + (1 - momentum) * mean
- self.ra_var.value = momentum * \
- self.ra_var.value + (1 - momentum) * var
-
- inv = (1 + self.gamma) / jnp.sqrt(var + epsilon)
- bn_output = (inputs - mean) * inv + self.beta
- bn_output *= 1.0 - padding
-
- return bn_output
-
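[Editorial note, not part of the patch] A minimal NumPy sketch of the masked statistics computed above: frames marked as padding contribute to neither the mean nor the variance.

import numpy as np

x = np.array([[1.0, 2.0, 3.0, 0.0]])[..., None]  # (batch, time, features=1)
paddings = np.array([[0.0, 0.0, 0.0, 1.0]])      # last frame is padding
mask = 1.0 - paddings[..., None]

count = np.maximum(np.sum(mask), 1.0)
mean = np.sum(x * mask) / count                  # 2.0 -- padding ignored
var = np.sum(((x - mean) ** 2) * mask) / count   # 2/3
print(mean, var)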
-
-class ConvolutionBlock(nn.Module):
- r"""Convolution block in conformer layer.
-
- architecture:
-
- input # (batch, time, hidden_dim)
- |
- layer_norm(.) # (batch, time, hidden_dim)
- dense(.), dense(.) # (batch, time, 2 * hidden_dim)
- | /
- glu(.) # (batch, time, hidden_dim)
- depthwise_conv1d(.)
- batch_norm(.)
- act(.)
- |
- dense(.)
- dropout(.)
- |
- output
- """
- config: ConformerConfig
-
- @nn.compact
- def __call__(self,
- inputs,
- input_paddings,
- train,
- update_batch_norm,
- use_running_average_bn):
- config = self.config
- inputs = LayerNorm(dim=config.encoder_dim)(inputs)
-
- input_gated1 = nn.Dense(
- config.encoder_dim,
- kernel_init=nn.initializers.xavier_uniform(),
- use_bias=True)(
- inputs)
-
- input_gated2 = nn.Dense(
- config.encoder_dim,
- kernel_init=nn.initializers.xavier_uniform(),
- use_bias=True)(
- inputs)
-
- inputs = input_gated1 * jax.nn.sigmoid(input_gated2)
- inputs = inputs * (1 - jnp.expand_dims(input_paddings, -1))
-
- inputs = nn.Conv(
- features=config.encoder_dim,
- kernel_size=(config.convolution_kernel_size,),
- strides=(1,),
- padding='SAME',
- feature_group_count=config.encoder_dim,
- use_bias=False,
- kernel_init=nn.initializers.xavier_uniform())(
- inputs)
-
- inputs = BatchNorm(config)(inputs,
- input_paddings,
- update_batch_norm,
- use_running_average_bn)
- if config.activation_function_name == 'swish':
- activation_fn = nn.swish
- elif config.activation_function_name == 'gelu':
- activation_fn = nn.gelu
- else:
- raise ValueError('Only "swish" and "gelu" are supported '
- 'config.activation_function_name values, received '
- f'{config.activation_function_name}')
- inputs = activation_fn(inputs)
- inputs = nn.Dense(
- config.encoder_dim, kernel_init=nn.initializers.xavier_uniform())(
- inputs)
-
- if config.conv_residual_dropout_rate is None:
- conv_residual_dropout_rate = 0.0
- else:
- conv_residual_dropout_rate = config.conv_residual_dropout_rate
- inputs = nn.Dropout(
- rate=conv_residual_dropout_rate, deterministic=not train)(
- inputs)
- return inputs
-
-
-class ConformerBlock(nn.Module):
- """Implements a single conformer encoder layer.
-
- High level overview:
-
- x = x + 0.5 * FeedForward(x)
- x = x + MHSA(x)
- x = x + ConvolutionBlock(x)
- x = x + 0.5 * FeedForward(x)
-
- y = layer_norm(x)
-
- """
- config: ConformerConfig
-
- @nn.compact
- def __call__(self,
- inputs,
- input_paddings,
- train,
- update_batch_norm,
- use_running_average):
- config = self.config
- padding_mask = jnp.expand_dims(1 - input_paddings, -1)
-
- inputs = inputs + 0.5 * FeedForwardModule(config=self.config)(
- inputs, padding_mask, train)
-
- inputs = inputs + MultiHeadedSelfAttention(config=self.config)(
- inputs, input_paddings, train)
-
- inputs = inputs + \
- ConvolutionBlock(config)(inputs,
- input_paddings,
- train,
- update_batch_norm,
- use_running_average
- )
-
- inputs = inputs + 0.5 * FeedForwardModule(config=self.config)(
- inputs, padding_mask, train)
-
- if config.use_post_layer_norm:
- inputs = LayerNorm(dim=config.encoder_dim)(inputs)
-
- return inputs
-
-
-class Conformer(nn.Module):
- """Conformer (encoder + decoder) block.
-
- Takes audio input signals and outputs a probability distribution over the
- vocabulary for each time step. The output is then fed into a CTC loss, which
- eliminates the need for alignment with targets.
- """
- config: ConformerConfig
-
- def setup(self):
- self.specaug = spectrum_augmenter.SpecAug(
- freq_mask_count=self.config.freq_mask_count,
- freq_mask_max_bins=self.config.freq_mask_max_bins,
- time_mask_count=self.config.time_mask_count,
- time_mask_max_frames=self.config.time_mask_max_frames,
- time_mask_max_ratio=self.config.time_mask_max_ratio,
- time_masks_per_frame=self.config.time_masks_per_frame,
- use_dynamic_time_mask_max_frames=self.config
- .use_dynamic_time_mask_max_frames)
-
- @nn.compact
- def __call__(self,
- inputs,
- input_paddings,
- train,
- update_batch_norm: Optional[bool] = None,
- use_running_average_bn: Optional[bool] = None):
- config = self.config
-
- outputs = inputs
- output_paddings = input_paddings
-
- # Set BN args if not supplied for backwards compatibility
- if update_batch_norm is None:
- update_batch_norm = train
- if use_running_average_bn is None:
- use_running_average_bn = not train
-
- # Compute normalized log mel spectrograms from input audio signal.
- preprocessing_config = preprocessor.LibrispeechPreprocessingConfig()
- outputs, output_paddings = preprocessor.MelFilterbankFrontend(
- preprocessing_config,
- per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
- per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)(
- outputs, output_paddings)
-
- # Ablate random parts of input along temporal and frequency dimension
- # following the specaug procedure in https://arxiv.org/abs/1904.08779.
- if train and config.use_specaug:
- outputs, output_paddings = self.specaug(outputs, output_paddings)
-
- # Subsample input by a factor of 4 by performing strided convolutions.
- if config.input_dropout_rate is None:
- input_dropout_rate = 0.1
- else:
- input_dropout_rate = config.input_dropout_rate
- outputs, output_paddings = Subsample(
- encoder_dim=config.encoder_dim,
- input_dropout_rate=input_dropout_rate)(
- outputs, output_paddings, train)
-
- # Run the conformer encoder layers.
- for _ in range(config.num_encoder_layers):
- outputs = ConformerBlock(config)(outputs,
- output_paddings,
- train,
- update_batch_norm,
- use_running_average_bn)
-
- outputs = LayerNorm(config.encoder_dim)(outputs)
- # Run the decoder which in this case is a trivial projection layer.
- outputs = nn.Dense(
- config.vocab_size,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- outputs)
-
- return outputs, output_paddings
\ No newline at end of file
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models_ref.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models_ref.py
deleted file mode 100644
index 7b7c9720a..000000000
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models_ref.py
+++ /dev/null
@@ -1,525 +0,0 @@
-r"""Deepspeech.
-
-This model uses a deepspeech2 network to convert speech to text.
-Paper: https://arxiv.org/abs/1512.02595
-
-# BiLSTM code contributed by bastings@
-# github : https://github.com/bastings
-# webpage : https://bastings.github.io/
-"""
-
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-from flax import linen as nn
-from flax import struct
-import jax
-from jax.experimental import rnn
-import jax.numpy as jnp
-
-from algoperf.workloads.librispeech_conformer.librispeech_jax import \
- librispeech_preprocessor as preprocessor
-from algoperf.workloads.librispeech_conformer.librispeech_jax import \
- spectrum_augmenter
-
-Array = jnp.ndarray
-StateType = Union[Array, Tuple[Array, ...]]
-PRNGKey = Any
-Shape = Tuple[int]
-Dtype = Any
-Carry = Any
-CarryHistory = Any
-Output = Any
-
-
-@struct.dataclass
-class DeepspeechConfig:
- """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
- vocab_size: int = 1024
- dtype: Any = jnp.float32
- encoder_dim: int = 512
- num_lstm_layers: int = 6
- num_ffn_layers: int = 3
- conv_subsampling_factor: int = 2
- conv_subsampling_layers: int = 2
- use_specaug: bool = True
- freq_mask_count: int = 2
- freq_mask_max_bins: int = 27
- time_mask_count: int = 10
- time_mask_max_frames: int = 40
- time_mask_max_ratio: float = 0.05
- time_masks_per_frame: float = 0.0
- use_dynamic_time_mask_max_frames: bool = True
- batch_norm_momentum: float = 0.999
- batch_norm_epsilon: float = 0.001
- # If None, defaults to 0.1.
- input_dropout_rate: Optional[float] = 0.1
- # If None, defaults to 0.1.
- feed_forward_dropout_rate: Optional[float] = 0.1
- enable_residual_connections: bool = True
- enable_decoder_layer_norm: bool = True
- bidirectional: bool = True
- use_tanh: bool = False
- layernorm_everywhere: bool = False
-
-
-class Subsample(nn.Module):
- """Module to perform strided convolution in order to subsample inputs.
-
- Attributes:
- config: DeepspeechConfig supplying the encoder dimension, dtype, batch norm
- settings, and the input dropout rate.
- """
- config: DeepspeechConfig
-
- @nn.compact
- def __call__(self, inputs, output_paddings, train):
- config = self.config
- outputs = jnp.expand_dims(inputs, axis=-1)
-
- outputs, output_paddings = Conv2dSubsampling(
- encoder_dim=config.encoder_dim,
- dtype=config.dtype,
- batch_norm_momentum=config.batch_norm_momentum,
- batch_norm_epsilon=config.batch_norm_epsilon,
- input_channels=1,
- output_channels=config.encoder_dim,
- use_tanh=config.use_tanh
- )(outputs, output_paddings, train)
-
- outputs, output_paddings = Conv2dSubsampling(
- encoder_dim=config.encoder_dim,
- dtype=config.dtype,
- batch_norm_momentum=config.batch_norm_momentum,
- batch_norm_epsilon=config.batch_norm_epsilon,
- input_channels=config.encoder_dim,
- output_channels=config.encoder_dim,
- use_tanh=config.use_tanh)(outputs, output_paddings, train)
-
- batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape
-
- outputs = jnp.reshape(
- outputs, (batch_size, subsampled_lengths, subsampled_dims * channels))
-
- outputs = nn.Dense(
- config.encoder_dim,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- outputs)
-
- if config.input_dropout_rate is None:
- input_dropout_rate = 0.1
- else:
- input_dropout_rate = config.input_dropout_rate
- outputs = nn.Dropout(
- rate=input_dropout_rate, deterministic=not train)(
- outputs)
-
- return outputs, output_paddings
-
-
-class Conv2dSubsampling(nn.Module):
- """Helper module used in Subsample layer.
-
- 1) Performs strided convolution over inputs and then applies non-linearity.
- 2) Also performs strided convolution over input_paddings to return the correct
- paddings for downstream layers.
- """
- input_channels: int = 0
- output_channels: int = 0
- filter_stride: List[int] = (2, 2)
- padding: str = 'SAME'
- encoder_dim: int = 0
- dtype: Any = jnp.float32
- batch_norm_momentum: float = 0.999
- batch_norm_epsilon: float = 0.001
- use_tanh: bool = False
-
- def setup(self):
- self.filter_shape = (3, 3, self.input_channels, self.output_channels)
- self.kernel = self.param('kernel',
- nn.initializers.xavier_uniform(),
- self.filter_shape)
- self.bias = self.param(
- 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels)
-
- @nn.compact
- def __call__(self, inputs, paddings, train):
- # Computing strided convolution to subsample inputs.
- feature_group_count = inputs.shape[3] // self.filter_shape[2]
- outputs = jax.lax.conv_general_dilated(
- lhs=inputs,
- rhs=self.kernel,
- window_strides=self.filter_stride,
- padding=self.padding,
- rhs_dilation=(1, 1),
- dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
- feature_group_count=feature_group_count)
-
- outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,))
-
- if self.use_tanh:
- outputs = nn.tanh(outputs)
- else:
- outputs = nn.relu(outputs)
-
- # Computing correct paddings post input convolution.
- input_length = paddings.shape[1]
- stride = self.filter_stride[0]
-
- pad_len = (input_length + stride - 1) // stride * stride - input_length
- out_padding = jax.lax.conv_general_dilated(
- lhs=paddings[:, :, None],
- rhs=jnp.ones([1, 1, 1]),
- window_strides=self.filter_stride[:1],
- padding=[(0, pad_len)],
- dimension_numbers=('NHC', 'HIO', 'NHC'))
- out_padding = jnp.squeeze(out_padding, axis=-1)
-
- # Mask outputs by correct paddings to ensure padded elements in inputs map
- # to padded value in outputs.
- outputs = outputs * (1.0 -
- jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1))
-
- return outputs, out_padding
-
-
-class FeedForwardModule(nn.Module):
- """Feedforward block of conformer layer."""
- config: DeepspeechConfig
-
- @nn.compact
- def __call__(self, inputs, input_paddings=None, train=False):
- padding_mask = jnp.expand_dims(1 - input_paddings, -1)
- config = self.config
-
- if config.layernorm_everywhere:
- inputs = LayerNorm(config.encoder_dim)(inputs)
- else:
- inputs = BatchNorm(config.encoder_dim,
- config.dtype,
- config.batch_norm_momentum,
- config.batch_norm_epsilon)(inputs,
- input_paddings,
- train)
- inputs = nn.Dense(
- config.encoder_dim,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- inputs)
- if config.use_tanh:
- inputs = nn.tanh(inputs)
- else:
- inputs = nn.relu(inputs)
- inputs *= padding_mask
-
- if config.feed_forward_dropout_rate is None:
- feed_forward_dropout_rate = 0.1
- else:
- feed_forward_dropout_rate = config.feed_forward_dropout_rate
- inputs = nn.Dropout(rate=feed_forward_dropout_rate)(
- inputs, deterministic=not train)
-
- return inputs
-
-
-class LayerNorm(nn.Module):
- """Module implementing layer normalization.
-
- This implementation is the same as in this paper:
- https://arxiv.org/pdf/1607.06450.pdf.
-
- Note: we multiply the normalized inputs by (1 + scale) and initialize scale
- to zeros. This differs from the default flax implementation, which multiplies
- by scale and initializes it to ones.
- """
- dim: int = 0
- epsilon: float = 1e-6
-
- def setup(self):
- self.scale = self.param('scale', nn.initializers.zeros, [self.dim])
- self.bias = self.param('bias', nn.initializers.zeros, [self.dim])
-
- @nn.compact
- def __call__(self, inputs):
- mean = jnp.mean(inputs, axis=-1, keepdims=True)
- var = jnp.mean(jnp.square(inputs - mean), axis=-1, keepdims=True)
-
- normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon)
- normed_inputs *= (1 + self.scale)
- normed_inputs += self.bias
-
- return normed_inputs
-
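[Editorial note, not part of the patch] A short sketch of the normalization above: because scale and bias are initialized to zero and the scale is applied as (1 + scale), the layer reduces to plain standardization at initialization.

import jax
import jax.numpy as jnp

x = jnp.array([[1.0, 2.0, 3.0]])
mean = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
# With scale = bias = 0, the output is simply the standardized input.
print((x - mean) * jax.lax.rsqrt(var + 1e-6))  # approx. [-1.22, 0.0, 1.22]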
-
-class BatchNorm(nn.Module):
- """Implements batch norm respecting input paddings.
-
- This implementation takes into account input padding by masking inputs before
- computing mean and variance.
-
- This is inspired by the lingvo JAX implementation of BatchNorm:
- https://github.com/tensorflow/lingvo/blob/84b85514d7ad3652bc9720cb45acfab08604519b/lingvo/jax/layers/normalizations.py#L92
-
- and the corresponding defaults for momentum and epsilon have been copied over
- from lingvo.
- """
- encoder_dim: int = 0
- dtype: Any = jnp.float32
- batch_norm_momentum: float = 0.999
- batch_norm_epsilon: float = 0.001
-
- def setup(self):
- dim = self.encoder_dim
- dtype = self.dtype
-
- self.ra_mean = self.variable('batch_stats',
- 'mean',
- lambda s: jnp.zeros(s, dtype),
- dim)
- self.ra_var = self.variable('batch_stats',
- 'var',
- lambda s: jnp.ones(s, dtype),
- dim)
-
- self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype)
- self.beta = self.param('bias', nn.initializers.zeros, dim, dtype)
-
- def _get_default_paddings(self, inputs):
- """Gets the default paddings for an input."""
- in_shape = list(inputs.shape)
- in_shape[-1] = 1
-
- return jnp.zeros(in_shape, dtype=inputs.dtype)
-
- @nn.compact
- def __call__(self, inputs, input_paddings=None, train=False):
- rank = inputs.ndim
- reduce_over_dims = list(range(0, rank - 1))
-
- if input_paddings is None:
- padding = self._get_default_paddings(inputs)
- else:
- padding = jnp.expand_dims(input_paddings, -1)
-
- momentum = self.batch_norm_momentum
- epsilon = self.batch_norm_epsilon
-
- if train:
- mask = 1.0 - padding
- sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True)
- count_v = jnp.sum(
- jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True)
-
- sum_v = jax.lax.psum(sum_v, axis_name='batch')
- count_v = jax.lax.psum(count_v, axis_name='batch')
-
- count_v = jnp.maximum(count_v, 1.0)
- mean = sum_v / count_v
- variance = (inputs - mean) * (inputs - mean) * mask
-
- sum_vv = jnp.sum(variance, axis=reduce_over_dims, keepdims=True)
-
- sum_vv = jax.lax.psum(sum_vv, axis_name='batch')
- var = sum_vv / count_v
-
- self.ra_mean.value = momentum * self.ra_mean.value + (1 - momentum) * mean
- self.ra_var.value = momentum * self.ra_var.value + (1 - momentum) * var
- else:
- mean = self.ra_mean.value
- var = self.ra_var.value
-
- inv = (1 + self.gamma) / jnp.sqrt(var + epsilon)
-
- bn_output = (inputs - mean) * inv + self.beta
- bn_output *= 1.0 - padding
-
- return bn_output
- # return inputs
-
-
-class CudnnLSTM(nn.Module):
- features: int
- num_layers: int = 1
- dropout_rate: float = 0.0
- bidirectional: bool = False
-
- @nn.compact
- def __call__(
- self,
- inputs: Array,
- segmentation_mask: Optional[Array] = None,
- return_carry: Optional[bool] = None,
- deterministic: bool = False,
- initial_states: Optional[Tuple[Array, Array]] = None,
- use_cuda: bool = True,
- ) -> Union[Array, Tuple[Array, Carry]]:
-
- if jax.devices()[0].platform != 'gpu':
- use_cuda = False
-
- batch_size = inputs.shape[0]
- input_size = inputs.shape[2]
- num_directions = 2 if self.bidirectional else 1
- dropout = 0.0 if deterministic else self.dropout_rate
-
- weights = self.param(
- 'weights',
- rnn.init_lstm_weight,
- input_size,
- self.features,
- self.num_layers,
- self.bidirectional,
- )
-
- if initial_states is None:
- h_0 = jnp.zeros(
- (num_directions * self.num_layers, batch_size, self.features),
- jnp.float32,
- )
- c_0 = jnp.zeros(
- (num_directions * self.num_layers, batch_size, self.features),
- jnp.float32,
- )
- else:
- h_0, c_0 = initial_states
-
- if segmentation_mask is not None:
- seq_lengths = jnp.sum(1 - segmentation_mask, axis=1, dtype=jnp.int32)
- else:
- seq_lengths = jnp.full((batch_size,), inputs.shape[1], dtype=jnp.int32)
-
- if use_cuda:
- y, h, c = rnn.lstm(
- x=inputs, h_0=h_0, c_0=c_0, weights=weights,
- seq_lengths=seq_lengths, input_size=input_size,
- hidden_size=self.features, num_layers=self.num_layers,
- dropout=dropout, bidirectional=self.bidirectional,
- )
- else:
- weight_ih, weight_hh, bias_ih, bias_hh = self.unpack_weights(
- weights, input_size)
- y, h, c = rnn.lstm_ref(
- x=inputs, h_0=h_0, c_0=c_0, W_ih=weight_ih, W_hh=weight_hh,
- b_ih=bias_ih, b_hh=bias_hh, seq_lengths=seq_lengths,
- input_size=input_size, hidden_size=self.features,
- num_layers=self.num_layers, dropout=dropout,
- bidirectional=self.bidirectional,
- )
-
- if return_carry:
- return y, (h, c)
-
- return y
-
- @nn.nowrap
- def unpack_weights(
- self, weights: Array, input_size: int
- ) -> Tuple[
- Dict[int, Array], Dict[int, Array], Dict[int, Array], Dict[int, Array]]:
- return jax.experimental.rnn.unpack_lstm_weights(
- weights,
- input_size,
- self.features,
- self.num_layers,
- self.bidirectional,
- )
-
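[Editorial note, not part of the patch] A small sketch of how valid sequence lengths are derived above: the segmentation mask uses 1.0 for padded steps, so the length is the count of zeros per row.

import jax.numpy as jnp

segmentation_mask = jnp.array([[0, 0, 0, 1, 1],
                               [0, 0, 0, 0, 0]])
seq_lengths = jnp.sum(1 - segmentation_mask, axis=1, dtype=jnp.int32)
print(seq_lengths)  # [3 5]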
-
-class BatchRNN(nn.Module):
- """Implements a single deepspeech encoder layer.
- """
- config: DeepspeechConfig
-
- @nn.compact
- def __call__(self, inputs, input_paddings, train):
- config = self.config
-
- if config.layernorm_everywhere:
- inputs = LayerNorm(config.encoder_dim)(inputs)
- else:
- inputs = BatchNorm(config.encoder_dim,
- config.dtype,
- config.batch_norm_momentum,
- config.batch_norm_epsilon)(inputs,
- input_paddings,
- train)
- output = CudnnLSTM(
- features=config.encoder_dim // 2,
- bidirectional=config.bidirectional,
- num_layers=1)(inputs, input_paddings)
-
- return output
-
-
-class Deepspeech(nn.Module):
- """Conformer (encoder + decoder) block.
-
- Takes audio input signals and outputs probability distribution over vocab size
- for each time step. The output is then fed into a CTC loss which eliminates
- the need for alignment with targets.
- """
- config: DeepspeechConfig
-
- def setup(self):
- config = self.config
- self.specaug = spectrum_augmenter.SpecAug(
- freq_mask_count=config.freq_mask_count,
- freq_mask_max_bins=config.freq_mask_max_bins,
- time_mask_count=config.time_mask_count,
- time_mask_max_frames=config.time_mask_max_frames,
- time_mask_max_ratio=config.time_mask_max_ratio,
- time_masks_per_frame=config.time_masks_per_frame,
- use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames
- )
-
- @nn.compact
- def __call__(self, inputs, input_paddings, train):
- config = self.config
-
- outputs = inputs
- output_paddings = input_paddings
-
- # Compute normalized log mel spectrograms from input audio signal.
- preprocessing_config = preprocessor.LibrispeechPreprocessingConfig()
- outputs, output_paddings = preprocessor.MelFilterbankFrontend(
- preprocessing_config,
- per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
- per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)(outputs,
- output_paddings)
-
- # Ablate random parts of input along temporal and frequency dimension
- # following the specaug procedure in https://arxiv.org/abs/1904.08779.
- if config.use_specaug and train:
- outputs, output_paddings = self.specaug(outputs, output_paddings)
-
- # Subsample input by a factor of 4 by performing strided convolutions.
- outputs, output_paddings = Subsample(
- config=config)(outputs, output_paddings, train)
-
- # Run the lstm layers.
- for _ in range(config.num_lstm_layers):
- if config.enable_residual_connections:
- outputs = outputs + BatchRNN(config)(outputs, output_paddings, train)
- else:
- outputs = BatchRNN(config)(outputs, output_paddings, train)
-
- for _ in range(config.num_ffn_layers):
- if config.enable_residual_connections:
- outputs = outputs + FeedForwardModule(config=self.config)(
- outputs, output_paddings, train)
- else:
- outputs = FeedForwardModule(config=self.config)(outputs,
- output_paddings,
- train)
-
- # Run the decoder which in this case is a trivial projection layer.
- if config.enable_decoder_layer_norm:
- outputs = LayerNorm(config.encoder_dim)(outputs)
-
- outputs = nn.Dense(
- config.vocab_size,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- outputs)
-
- return outputs, output_paddings
\ No newline at end of file
diff --git a/algoperf/workloads/ogbg/ogbg_jax/models_ref.py b/algoperf/workloads/ogbg/ogbg_jax/models_ref.py
deleted file mode 100644
index ca3d89426..000000000
--- a/algoperf/workloads/ogbg/ogbg_jax/models_ref.py
+++ /dev/null
@@ -1,84 +0,0 @@
-# Forked from the init2winit implementation here
-# https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py.
-from typing import Optional, Tuple
-
-from flax import linen as nn
-import jax.numpy as jnp
-import jraph
-
-
-def _make_embed(latent_dim, name):
-
- def make_fn(inputs):
- return nn.Dense(features=latent_dim, name=name)(inputs)
-
- return make_fn
-
-
-def _make_mlp(hidden_dims, activation_fn, train, dropout_rate):
- """Creates a MLP with specified dimensions."""
-
- @jraph.concatenated_args
- def make_fn(inputs):
- x = inputs
- for dim in hidden_dims:
- x = nn.Dense(features=dim)(x)
- x = nn.LayerNorm()(x)
- x = activation_fn(x)
- x = nn.Dropout(rate=dropout_rate, deterministic=not train)(x)
- return x
-
- return make_fn
-
-
-class GNN(nn.Module):
- """Defines a graph network.
- The model assumes the input data is a jraph.GraphsTuple without global
- variables. The final prediction will be encoded in the globals.
- """
- num_outputs: int
- latent_dim: int = 256
- hidden_dims: Tuple[int] = (256,)
- # If None, defaults to 0.1.
- dropout_rate: Optional[float] = 0.1
- num_message_passing_steps: int = 5
- activation_fn_name: str = 'relu'
-
- @nn.compact
- def __call__(self, graph, train):
- dropout_rate = self.dropout_rate
-
- graph = graph._replace(
- globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]))
-
- embedder = jraph.GraphMapFeatures(
- embed_node_fn=_make_embed(self.latent_dim, name='node_embedding'),
- embed_edge_fn=_make_embed(self.latent_dim, name='edge_embedding'))
- graph = embedder(graph)
-
- if self.activation_fn_name == 'relu':
- activation_fn = nn.relu
- elif self.activation_fn_name == 'gelu':
- activation_fn = nn.gelu
- elif self.activation_fn_name == 'silu':
- activation_fn = nn.silu
- else:
- raise ValueError(
- f'Invalid activation function name: {self.activation_fn_name}')
-
- for _ in range(self.num_message_passing_steps):
- net = jraph.GraphNetwork(
- update_edge_fn=_make_mlp(
- self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate),
- update_node_fn=_make_mlp(
- self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate),
- update_global_fn=_make_mlp(
- self.hidden_dims, activation_fn=activation_fn, train=train, dropout_rate=dropout_rate))
-
- graph = net(graph)
-
- # Map globals to represent the final result
- decoder = jraph.GraphMapFeatures(embed_global_fn=nn.Dense(self.num_outputs))
- graph = decoder(graph)
-
- return graph.globals
\ No newline at end of file
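[Editorial note, not part of the patch] A hedged usage sketch of the GNN defined above; the graph layout, feature sizes, and num_outputs are illustrative assumptions rather than the workload's actual dimensions.

import jax
import jax.numpy as jnp
import jraph

# Toy graph: 3 nodes, 2 directed edges; feature sizes are arbitrary here.
graph = jraph.GraphsTuple(
    nodes=jnp.ones((3, 9)),
    edges=jnp.ones((2, 3)),
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 2]),
    globals=None,               # GNN.__call__ replaces globals with zeros
    n_node=jnp.array([3]),
    n_edge=jnp.array([2]))

model = GNN(num_outputs=128)
params = model.init(jax.random.PRNGKey(0), graph, train=False)
preds = model.apply(params, graph, train=False)
print(preds.shape)  # (1, 128)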
diff --git a/algoperf/workloads/wmt/wmt_jax/models_ref.py b/algoperf/workloads/wmt/wmt_jax/models_ref.py
deleted file mode 100644
index e1f44aaa6..000000000
--- a/algoperf/workloads/wmt/wmt_jax/models_ref.py
+++ /dev/null
@@ -1,604 +0,0 @@
-"""Transformer-based machine translation model.
-
-Reference https://github.com/google/flax/tree/main/examples/wmt.
-"""
-
-from typing import Any, Callable, Optional
-
-from flax import linen as nn
-from flax import struct
-from jax import lax
-import jax.numpy as jnp
-import numpy as np
-
-
-@struct.dataclass
-class TransformerConfig:
- """Global hyperparameters used to minimize obnoxious kwarg plumbing."""
- share_embeddings: bool = True
- dtype: Any = jnp.float32
- vocab_size: int = 32000
- emb_dim: int = 1024
- num_heads: int = 16
- num_layers: int = 6
- qkv_dim: int = 1024
- mlp_dim: int = 1024
- max_len: int = 256
- activation: Callable = nn.relu
- glu: bool = False
- #If None, defaults to 0.1.
- dropout_rate: Optional[float] = 0.1
- #If None, defaults to 0.1.
- attention_dropout_rate: Optional[float] = 0.1
- attention_temp: float = 1.0
- deterministic: bool = False
- decode: bool = False
- kernel_init: Callable = nn.initializers.xavier_uniform()
- bias_init: Callable = nn.initializers.normal(stddev=1e-6)
- posemb_init: Optional[Callable] = None
- pre_ln: bool = True
-
-
-def shift_right(x, axis=1):
- """Shift the input to the right by padding on axis 1."""
- pad_widths = [(0, 0)] * len(x.shape)
- pad_widths[axis] = (1, 0)
- padded = jnp.pad(
- x, pad_widths, mode='constant', constant_values=x.dtype.type(0))
- return padded[:, :-1]
-
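[Editorial note, not part of the patch] A tiny illustration of the right shift used for teacher forcing: a zero is padded on the left and the last token is dropped.

import jax.numpy as jnp

y = jnp.array([[5, 6, 7]])
padded = jnp.pad(y, ((0, 0), (1, 0)), mode='constant', constant_values=0)
print(padded[:, :-1])  # [[0 5 6]] -- the decoder input is shifted right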
-
-def sinusoidal_init(max_len=2048, min_scale=1.0, max_scale=10000.0):
- """1D Sinusoidal Position Embedding Initializer.
-
- Args:
- max_len: maximum possible length for the input.
- min_scale: float: minimum frequency-scale in sine grating.
- max_scale: float: maximum frequency-scale in sine grating.
-
- Returns:
- output: init function returning `(1, max_len, d_feature)`
- """
-
- def init(key, shape, dtype=np.float32):
- """Sinusoidal init."""
- del key, dtype
- d_feature = shape[-1]
- pe = np.zeros((max_len, d_feature), dtype=np.float32)
- position = np.arange(0, max_len)[:, np.newaxis]
- scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1)
- div_term = min_scale * np.exp(np.arange(0, d_feature // 2) * scale_factor)
- pe[:, :d_feature // 2] = np.sin(position * div_term)
- pe[:, d_feature // 2:2 * (d_feature // 2)] = np.cos(position * div_term)
- pe = pe[np.newaxis, :, :] # [1, max_len, d_feature]
- return jnp.array(pe)
-
- return init
-
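[Editorial note, not part of the patch] For orientation, using the sinusoidal_init defined above: the initializer ignores its key and dtype arguments and returns a fixed table of shape (1, max_len, d_feature).

# Assumes sinusoidal_init from above is in scope.
table = sinusoidal_init(max_len=256)(None, (1, 256, 1024))
assert table.shape == (1, 256, 1024)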
-
-class AddPositionEmbs(nn.Module):
- """Adds (optionally learned) positional embeddings to the inputs.
-
- Attributes:
- config: TransformerConfig dataclass containing hyperparameters.
- decode: whether to run in single-position autoregressive mode.
- """
- config: TransformerConfig
- decode: bool = False
-
- @nn.compact
- def __call__(self, inputs, inputs_positions=None):
- """Applies AddPositionEmbs module.
-
- By default this layer uses a fixed sinusoidal embedding table. If a
- learned position embedding is desired, pass an initializer to
- posemb_init in the configuration.
-
- Args:
- inputs: input data.
- inputs_positions: input position indices for packed sequences.
-
- Returns:
- output: `(bs, timesteps, in_dim)`
- """
- cfg = self.config
- # inputs.shape is (batch_size, seq_len, emb_dim)
- assert inputs.ndim == 3, ('Number of dimensions should be 3,'
- f' but it is: {inputs.ndim}')
- length = inputs.shape[1]
- pos_emb_shape = (1, cfg.max_len, inputs.shape[-1])
- if cfg.posemb_init is None:
- # Use a fixed (non-learned) sinusoidal position embedding.
- pos_embedding = sinusoidal_init(max_len=cfg.max_len)(None,
- pos_emb_shape,
- None)
- else:
- pos_embedding = self.param('pos_embedding',
- cfg.posemb_init,
- pos_emb_shape)
- pe = pos_embedding[:, :length, :]
-
- # We use a cache position index for tracking decoding position.
- if self.decode:
- is_initialized = self.has_variable('cache', 'cache_index')
- cache_index = self.variable('cache',
- 'cache_index',
- lambda: jnp.array(0, dtype=jnp.uint32))
- if is_initialized:
- i = cache_index.value
- cache_index.value = i + 1
- _, _, df = pos_embedding.shape
- pe = lax.dynamic_slice(pos_embedding, jnp.array((0, i, 0)), (1, 1, df))
- if inputs_positions is None:
- # normal unpacked case:
- return inputs + pe
- else:
- # for packed data we need to use known position indices:
- return inputs + jnp.take(pe[0], inputs_positions, axis=0)
-
-
-class MlpBlock(nn.Module):
- """Transformer MLP / feed-forward block.
-
- Attributes:
- config: TransformerConfig dataclass containing hyperparameters.
- out_dim: optionally specify out dimension.
- """
- config: TransformerConfig
- out_dim: Optional[int] = None
-
- @nn.compact
- def __call__(self, inputs):
- """Applies Transformer MlpBlock module."""
- cfg = self.config
- actual_out_dim = (
- inputs.shape[-1] if self.out_dim is None else self.out_dim)
- x = nn.Dense(
- cfg.mlp_dim,
- dtype=cfg.dtype,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init)(
- inputs)
- x = cfg.activation(x)
- if cfg.glu:
- y = nn.Dense(
- cfg.mlp_dim,
- dtype=cfg.dtype,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init)(
- inputs)
- x = x * y
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
- output = nn.Dense(
- actual_out_dim,
- dtype=cfg.dtype,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init)(
- x)
- output = nn.Dropout(rate=dropout_rate)(
- output, deterministic=cfg.deterministic)
- return output
-
-
-class Encoder1DBlock(nn.Module):
- """Transformer encoder layer.
-
- Attributes:
- config: TransformerConfig dataclass containing hyperparameters.
- """
- config: TransformerConfig
-
- @nn.compact
- def __call__(self, inputs, encoder_mask=None):
- """Applies Encoder1DBlock module.
-
- Args:
- inputs: input data.
- encoder_mask: encoder self-attention mask.
-
- Returns:
- output after transformer encoder block.
- """
- cfg = self.config
- pre_ln = cfg.pre_ln
-
- # Attention block.
- assert inputs.ndim == 3
- x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs
- if cfg.attention_dropout_rate is None:
- attention_dropout_rate = 0.1
- else:
- attention_dropout_rate = cfg.attention_dropout_rate
- x = nn.MultiHeadDotProductAttention(
- num_heads=cfg.num_heads,
- dtype=cfg.dtype,
- qkv_features=cfg.qkv_dim,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- use_bias=False,
- broadcast_dropout=False,
- dropout_rate=attention_dropout_rate,
- deterministic=cfg.deterministic)(
- cfg.attention_temp * x, x, mask=encoder_mask)
-
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
- x = x + inputs
- if not pre_ln:
- x = nn.LayerNorm(dtype=cfg.dtype)(x)
-
- # MLP block.
- y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
- y = MlpBlock(config=cfg)(y)
-
- return x + y if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(x + y)
-
-
-class EncoderDecoder1DBlock(nn.Module):
- """Transformer encoder-decoder layer.
-
- Attributes:
- config: TransformerConfig dataclass containing hyperparameters.
- """
- config: TransformerConfig
-
- @nn.compact
- def __call__(self,
- targets,
- encoded,
- decoder_mask=None,
- encoder_decoder_mask=None):
- """Applies EncoderDecoder1DBlock module.
-
- Args:
- targets: input data for decoder
- encoded: input data from encoder
- decoder_mask: decoder self-attention mask.
- encoder_decoder_mask: encoder-decoder attention mask.
-
- Returns:
- output after transformer encoder-decoder block.
- """
- cfg = self.config
- pre_ln = cfg.pre_ln
-
- # Decoder block.
- assert targets.ndim == 3
- x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets
-
- if cfg.attention_dropout_rate is None:
- attention_dropout_rate = 0.1
- else:
- attention_dropout_rate = cfg.attention_dropout_rate
- x = nn.MultiHeadDotProductAttention(
- num_heads=cfg.num_heads,
- dtype=cfg.dtype,
- qkv_features=cfg.qkv_dim,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- use_bias=False,
- broadcast_dropout=False,
- dropout_rate=attention_dropout_rate,
- deterministic=cfg.deterministic,
- decode=cfg.decode)(
- cfg.attention_temp * x, x, mask=decoder_mask)
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
- x = x + targets
- if not pre_ln:
- x = nn.LayerNorm(dtype=cfg.dtype)(x)
-
- # Encoder-Decoder block.
- y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
- y = nn.MultiHeadDotProductAttention(
- num_heads=cfg.num_heads,
- dtype=cfg.dtype,
- qkv_features=cfg.qkv_dim,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- use_bias=False,
- broadcast_dropout=False,
- dropout_rate=attention_dropout_rate,
- deterministic=cfg.deterministic)(
- cfg.attention_temp * y, encoded, mask=encoder_decoder_mask)
-
- y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic)
- y = y + x
- if not pre_ln:
- y = nn.LayerNorm(dtype=cfg.dtype)(y)
-
- # MLP block.
- z = nn.LayerNorm(dtype=cfg.dtype)(y) if pre_ln else y
- z = MlpBlock(config=cfg)(z)
-
- return y + z if pre_ln else nn.LayerNorm(dtype=cfg.dtype)(y + z)
-
-
-class Encoder(nn.Module):
- """Transformer Model Encoder for sequence to sequence translation.
-
- Attributes:
- config: TransformerConfig dataclass containing hyperparameters.
- shared_embedding: a shared embedding layer to use.
- """
- config: TransformerConfig
- shared_embedding: Any = None
-
- @nn.compact
- def __call__(self, inputs, inputs_positions=None, encoder_mask=None):
- """Applies Transformer model on the inputs.
-
- Args:
- inputs: input data
- inputs_positions: input subsequence positions for packed examples.
- encoder_mask: decoder self-attention mask.
-
- Returns:
- output of a transformer encoder.
- """
- cfg = self.config
- assert inputs.ndim == 2 # (batch, len)
-
- # Input Embedding
- if self.shared_embedding is None:
- input_embed = nn.Embed(
- num_embeddings=cfg.vocab_size,
- features=cfg.emb_dim,
- embedding_init=nn.initializers.normal(stddev=1.0))
- else:
- input_embed = self.shared_embedding
- x = inputs.astype('int32')
- x = input_embed(x)
- x = AddPositionEmbs(
- config=cfg, decode=False, name='posembed_input')(
- x, inputs_positions=inputs_positions)
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- x = nn.Dropout(rate=dropout_rate)(x, deterministic=cfg.deterministic)
-
- x = x.astype(cfg.dtype)
-
- # Input Encoder
- for lyr in range(cfg.num_layers):
- x = Encoder1DBlock(
- config=cfg, name=f'encoderblock_{lyr}')(x, encoder_mask)
-
- encoded = (
- nn.LayerNorm(dtype=cfg.dtype, name='encoder_layernorm')(x)
- if cfg.pre_ln else x)
-
- return encoded
-
-
-class Decoder(nn.Module):
- """Transformer Model Decoder for sequence to sequence translation.
-
- Attributes:
- config: TransformerConfig dataclass containing hyperparameters.
- shared_embedding: a shared embedding layer to use.
- """
- config: TransformerConfig
- shared_embedding: Any = None
-
- @nn.compact
- def __call__(self,
- encoded,
- targets,
- targets_positions=None,
- decoder_mask=None,
- encoder_decoder_mask=None):
- """Applies Transformer model on the inputs.
-
- Args:
- encoded: encoded input data from encoder.
- targets: target inputs.
- targets_positions: input subsequence positions for packed examples.
- decoder_mask: decoder self-attention mask.
- encoder_decoder_mask: encoder-decoder attention mask.
-
- Returns:
- output of a transformer decoder.
- """
- cfg = self.config
-
- assert encoded.ndim == 3 # (batch, len, depth)
- assert targets.ndim == 2 # (batch, len)
-
- # Target Embedding
- if self.shared_embedding is None:
- output_embed = nn.Embed(
- num_embeddings=cfg.vocab_size,
- features=cfg.emb_dim,
- embedding_init=nn.initializers.normal(stddev=1.0))
- else:
- output_embed = self.shared_embedding
-
- y = targets.astype('int32')
- if not cfg.decode:
- y = shift_right(y)
- y = output_embed(y)
- y = AddPositionEmbs(
- config=cfg, decode=cfg.decode, name='posembed_output')(
- y, inputs_positions=targets_positions)
- if cfg.dropout_rate is None:
- dropout_rate = 0.1
- else:
- dropout_rate = cfg.dropout_rate
- y = nn.Dropout(rate=dropout_rate)(y, deterministic=cfg.deterministic)
-
- y = y.astype(cfg.dtype)
-
- # Target-Input Decoder
- for lyr in range(cfg.num_layers):
- y = EncoderDecoder1DBlock(
- config=cfg, name=f'encoderdecoderblock_{lyr}')(
- y,
- encoded,
- decoder_mask=decoder_mask,
- encoder_decoder_mask=encoder_decoder_mask)
- y = (
- nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_layernorm')(y)
- if cfg.pre_ln else y)
-
- # Use the transpose of embedding matrix for logit transform.
- logits = output_embed.attend(y.astype(jnp.float32))
- # Correctly normalize pre-softmax logits for this shared case.
- logits = logits / jnp.sqrt(y.shape[-1])
- return logits
-
-
-class Transformer(nn.Module):
- """Transformer Model for sequence to sequence translation.
-
- Attributes:
- config: TransformerConfig dataclass containing hyperparameters.
- """
- config: TransformerConfig
-
- def setup(self):
- cfg = self.config
-
- if cfg.share_embeddings:
- if cfg.vocab_size is not None:
- assert cfg.vocab_size == cfg.vocab_size, (
- "can't share embedding with different vocab sizes.")
- self.shared_embedding = nn.Embed(
- num_embeddings=cfg.vocab_size,
- features=cfg.emb_dim,
- embedding_init=nn.initializers.normal(stddev=1.0))
- else:
- self.shared_embedding = None
-
- self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding)
- self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding)
-
- def encode(self, inputs, inputs_positions=None, inputs_segmentation=None):
- """Applies Transformer encoder-branch on the inputs.
-
- Args:
- inputs: input data.
- inputs_positions: input subsequence positions for packed examples.
- inputs_segmentation: input segmentation info for packed examples.
-
- Returns:
- encoded feature array from the transformer encoder.
- """
- cfg = self.config
- # Make padding attention mask.
- encoder_mask = nn.make_attention_mask(
- inputs > 0, inputs > 0, dtype=cfg.dtype)
- # Add segmentation block-diagonal attention mask if using segmented data.
- if inputs_segmentation is not None:
- encoder_mask = nn.combine_masks(
- encoder_mask,
- nn.make_attention_mask(
- inputs_segmentation,
- inputs_segmentation,
- jnp.equal,
- dtype=cfg.dtype))
- return self.encoder(
- inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask)
-
- def decode(
- self,
- encoded,
- inputs, # only needed for masks
- targets,
- targets_positions=None,
- inputs_segmentation=None,
- targets_segmentation=None):
- """Applies Transformer decoder-branch on encoded-input and target.
-
- Args:
- encoded: encoded input data from encoder.
- inputs: input data (only needed for masking).
- targets: target data.
- targets_positions: target subsequence positions for packed examples.
- inputs_segmentation: input segmentation info for packed examples.
- targets_segmentation: target segmentation info for packed examples.
-
- Returns:
- logits array from transformer decoder.
- """
- cfg = self.config
-
- # Make padding attention masks.
- if cfg.decode:
- decoder_mask = None
- encoder_decoder_mask = nn.make_attention_mask(
- jnp.ones_like(targets) > 0, inputs > 0, dtype=cfg.dtype)
- else:
- decoder_mask = nn.combine_masks(
- nn.make_attention_mask(targets > 0, targets > 0, dtype=cfg.dtype),
- nn.make_causal_mask(targets, dtype=cfg.dtype))
- encoder_decoder_mask = nn.make_attention_mask(
- targets > 0, inputs > 0, dtype=cfg.dtype)
-
- # Add segmentation block-diagonal attention masks if using segmented data.
- if inputs_segmentation is not None:
- decoder_mask = nn.combine_masks(
- decoder_mask,
- nn.make_attention_mask(
- targets_segmentation,
- targets_segmentation,
- jnp.equal,
- dtype=cfg.dtype))
- encoder_decoder_mask = nn.combine_masks(
- encoder_decoder_mask,
- nn.make_attention_mask(
- targets_segmentation,
- inputs_segmentation,
- jnp.equal,
- dtype=cfg.dtype))
- logits = self.decoder(
- encoded,
- targets,
- targets_positions=targets_positions,
- decoder_mask=decoder_mask,
- encoder_decoder_mask=encoder_decoder_mask)
- return logits.astype(self.config.dtype)
-
- def __call__(self,
- inputs,
- targets,
- inputs_positions=None,
- targets_positions=None,
- inputs_segmentation=None,
- targets_segmentation=None):
- """Applies Transformer model on the inputs.
-
- Args:
- inputs: input data.
- targets: target data.
- inputs_positions: input subsequence positions for packed examples.
- targets_positions: target subsequence positions for packed examples.
- inputs_segmentation: input segmentation info for packed examples.
- targets_segmentation: target segmentation info for packed examples.
-
- Returns:
- logits array from full transformer.
- """
- encoded = self.encode(
- inputs,
- inputs_positions=inputs_positions,
- inputs_segmentation=inputs_segmentation)
-
- return self.decode(
- encoded,
- inputs, # only used for masks
- targets,
- targets_positions=targets_positions,
- inputs_segmentation=inputs_segmentation,
- targets_segmentation=targets_segmentation)
\ No newline at end of file
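[Editorial note, not part of the patch] A hedged end-to-end sketch of the Transformer defined above; the tiny hyperparameters are illustrative assumptions chosen only to keep the example light.

import jax
import jax.numpy as jnp

cfg = TransformerConfig(
    vocab_size=32, emb_dim=16, num_heads=2, num_layers=1,
    qkv_dim=16, mlp_dim=32, max_len=8, deterministic=True)
model = Transformer(cfg)

inputs = jnp.ones((2, 8), dtype=jnp.int32)   # (batch, length) token ids
targets = jnp.ones((2, 8), dtype=jnp.int32)
variables = model.init(jax.random.PRNGKey(0), inputs, targets)
logits = model.apply(variables, inputs, targets)
print(logits.shape)  # (2, 8, 32) -- per-position vocabulary logits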
From 161c264c7c007f9d3f065f2affb55091cc61e492 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 25 Jun 2025 05:32:28 +0000
Subject: [PATCH 088/123] lint fix
---
algoperf/workloads/ogbg/ogbg_jax/models.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py
index 7207e033d..8524bb60e 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/models.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/models.py
@@ -93,4 +93,4 @@ def __call__(self, graph, train, dropout_rate=DROPOUT_RATE):
decoder = jraph.GraphMapFeatures(embed_global_fn=nn.Dense(self.num_outputs))
graph = decoder(graph)
- return graph.globals
\ No newline at end of file
+ return graph.globals
From ac76d4ff7258fba7882d3e414c4b7f97fc0627de Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 25 Jun 2025 05:47:05 +0000
Subject: [PATCH 089/123] formatting fixes
---
tests/test_jax_utils.py | 43 ++++++++++++++++++++---------------------
1 file changed, 21 insertions(+), 22 deletions(-)
diff --git a/tests/test_jax_utils.py b/tests/test_jax_utils.py
index 8a156149b..7ee61d1e1 100644
--- a/tests/test_jax_utils.py
+++ b/tests/test_jax_utils.py
@@ -4,7 +4,6 @@
"""
from functools import partial
-import os
from absl.testing import absltest
from absl.testing import parameterized
@@ -63,7 +62,7 @@ def __call__(self, x, train, dropout_rate=DEFAULT_DROPOUT):
x, rate=dropout_rate)
-class ModelEquivalenceTest(parameterized.TestCase):
+class DropoutTest(parameterized.TestCase):
@parameterized.named_parameters(
dict(
@@ -185,8 +184,7 @@ def test_dropout_update(self, dropout_rate, mode):
dict(testcase_name="Dropout, p=0.1, eval", dropout_rate=0.1, mode="eval"),
)
def test_jitted_updates(self, dropout_rate, mode):
- """ Compare forward pass of Dropout layer to flax.linen.Dropout in train and
- eval mode.
+ """ Compare jitted updates with dropout.
"""
# initialize models
@@ -214,24 +212,25 @@ def test_jitted_updates(self, dropout_rate, mode):
jitted_custom_apply = jax.jit(
partial(cust_model.apply), static_argnames=['train'])
- def multiple_fwd_passes_custom_layer():
- for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
- y2 = jitted_custom_apply(
- initial_variables_custom,
- x,
- train=train,
- dropout_rate=d,
- rngs={"dropout": dropout_rng},
- )
- return y2
-
- def multiple_fwd_passes_original_layer():
- for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
- y1 = jitted_original_apply(
- initial_variables_original,
- x,
- train=train,
- rngs={"dropout": dropout_rng})
+ for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
+ y1 = jitted_original_apply(
+ initial_variables_original,
+ x,
+ train=train,
+ rngs={"dropout": dropout_rng})
+ return y1
+
+ for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
+ y2 = jitted_custom_apply(
+ initial_variables_custom,
+ x,
+ train=train,
+ dropout_rate=d,
+ rngs={"dropout": dropout_rng},
+ )
+ return y2
+
+ assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
if __name__ == "__main__":
From e4eaceaf57e32b74cafd2a31e16cc64bb5a354ab Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 25 Jun 2025 06:10:58 +0000
Subject: [PATCH 090/123] fix linting
---
tests/test_jax_utils.py | 11 ++++-------
1 file changed, 4 insertions(+), 7 deletions(-)
diff --git a/tests/test_jax_utils.py b/tests/test_jax_utils.py
index 7ee61d1e1..a90b44f62 100644
--- a/tests/test_jax_utils.py
+++ b/tests/test_jax_utils.py
@@ -22,13 +22,12 @@
def pytrees_are_equal(a, b, rtol=1e-5, atol=1e-8):
"""
- A custom function to check if two PyTrees are equal, handling floats with a tolerance.
+ A custom function to check if two PyTrees are equal, handling floats with
+ a tolerance.
"""
- # 1. Check if the structures are the same
if tree_structure(a) != tree_structure(b):
return False
- # 2. Define a comparison function for leaves
def leaf_comparator(x, y):
# Use allclose for floating-point JAX arrays
if isinstance(x, jnp.ndarray) and jnp.issubdtype(x.dtype, jnp.floating):
@@ -37,8 +36,6 @@ def leaf_comparator(x, y):
else:
return x == y
- # 3. Map the comparison function over the trees and check if all results are True
- # We also need to flatten the results of the tree_map and check if all are True
comparison_tree = tree_map(leaf_comparator, a, b)
all_equal = all(tree_leaves(comparison_tree))
@@ -80,7 +77,7 @@ def test_forward(self, dropout_rate, mode):
"""
# initialize models
- rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
+ rng, dropout_rng = jax.random.split(jax.random.key(SEED), 2)
fake_batch = jnp.ones((10,))
orig_model = LegacyDropoutModel(dropout_rate=dropout_rate)
cust_model = DropoutModel()
@@ -130,7 +127,7 @@ def test_dropout_update(self, dropout_rate, mode):
eval mode.
"""
# init model
- rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
+ rng, dropout_rng = jax.random.split(jax.random.key(SEED), 2)
fake_batch = jnp.ones((10,))
orig_model = LegacyDropoutModel(dropout_rate=dropout_rate)
cust_model = DropoutModel()
From 0f430495f6ecd6ab6efcebcbd99bbdb7ca04f018 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 25 Jun 2025 06:11:35 +0000
Subject: [PATCH 091/123] fix linting
---
tests/test_jax_utils.py | 3 ---
1 file changed, 3 deletions(-)
diff --git a/tests/test_jax_utils.py b/tests/test_jax_utils.py
index a90b44f62..22a3d0929 100644
--- a/tests/test_jax_utils.py
+++ b/tests/test_jax_utils.py
@@ -215,7 +215,6 @@ def test_jitted_updates(self, dropout_rate, mode):
x,
train=train,
rngs={"dropout": dropout_rng})
- return y1
for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
y2 = jitted_custom_apply(
@@ -225,8 +224,6 @@ def test_jitted_updates(self, dropout_rate, mode):
dropout_rate=d,
rngs={"dropout": dropout_rng},
)
- return y2
-
assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
From a151382d5c61ce29f80003869cf0be83f374aadb Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Wed, 25 Jun 2025 06:20:47 +0000
Subject: [PATCH 092/123] pylint fix
---
tests/test_jax_utils.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/test_jax_utils.py b/tests/test_jax_utils.py
index 22a3d0929..d54bf47aa 100644
--- a/tests/test_jax_utils.py
+++ b/tests/test_jax_utils.py
@@ -185,7 +185,7 @@ def test_jitted_updates(self, dropout_rate, mode):
"""
# initialize models
- rng, data_rng, dropout_rng = jax.random.split(jax.random.key(SEED), 3)
+ rng, dropout_rng = jax.random.split(jax.random.key(SEED), 2)
fake_batch = jnp.ones((10,))
orig_model = LegacyDropoutModel(dropout_rate=dropout_rate)
cust_model = DropoutModel()
From 78a14095273f38f03697005b37e7365560bafe55 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Wed, 18 Jun 2025 14:17:23 +0200
Subject: [PATCH 093/123] Formatting, ignore `.eggs/` in yapf
---
pyproject.toml | 19 +++++--------------
1 file changed, 5 insertions(+), 14 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 1daa72848..f72491380 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,7 +46,6 @@ dependencies = [
"clu==0.0.12",
"matplotlib>=3.9.2",
"tabulate==0.9.0",
-
]
[build-system]
@@ -69,9 +68,7 @@ version_file = "algoperf/_version.py"
###############################################################################
[project.optional-dependencies]
# All workloads
-full = [
- "algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]",
-]
+full = ["algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]"]
# All workloads plus development dependencies
full_dev = ["algoperf[full,dev]"]
# Dependencies for developing the package
@@ -104,11 +101,7 @@ jax_core_deps = [
"ml_dtypes==0.4.1",
"protobuf==4.25.5",
]
-jax_cpu = [
- "jax==0.4.28",
- "jaxlib==0.4.28",
- "algoperf[jax_core_deps]",
-]
+jax_cpu = ["jax==0.4.28", "jaxlib==0.4.28", "algoperf[jax_core_deps]"]
jax_gpu = [
"jax==0.4.28",
"jaxlib==0.4.28",
@@ -117,10 +110,8 @@ jax_gpu = [
"algoperf[jax_core_deps]",
]
pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"]
-pytorch_gpu = [
- "torch==2.5.1",
- "torchvision==0.20.1",
-] # Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA.
+pytorch_gpu = ["torch==2.5.1", "torchvision==0.20.1"]
+# Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA.
###############################################################################
# Linting Configurations #
@@ -133,7 +124,7 @@ each_dict_entry_on_separate_line = false
split_all_top_level_comma_separated_values = true
column_limit = 80
[tool.yapfignore]
-ignore_patterns = ["algoperf/_version.py"]
+ignore_patterns = ["algoperf/_version.py", ".eggs/**"]
# isort configuration
[tool.isort]
From 462e8b7988ab722682e0f2ceb18785ec2028cd10 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Wed, 18 Jun 2025 15:44:01 +0200
Subject: [PATCH 094/123] Replace yapf, pylint, isort with ruff
---
pyproject.toml | 258 ++++++++-----------------------------------------
1 file changed, 39 insertions(+), 219 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index f72491380..1532a66c7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,9 +27,9 @@ classifiers = [
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
- "Programming Language :: Python :: 3.8",
- "Programming Language :: Python :: 3.9",
- "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Programming Language :: Python :: 3.13",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
@@ -72,13 +72,7 @@ full = ["algoperf[criteo1tb,fastmri,ogbg,librispeech_conformer,wmt]"]
# All workloads plus development dependencies
full_dev = ["algoperf[full,dev]"]
# Dependencies for developing the package
-dev = [
- "isort==5.12.0",
- "pylint==2.17.4",
- "pytest==8.3.3",
- "yapf==0.32.0",
- "pre-commit==4.0.1",
-]
+dev = ["ruff==0.12.0", "pytest==8.3.3", "pre-commit==4.0.1"]
wandb = ["wandb==0.19.6"]
@@ -111,215 +105,41 @@ jax_gpu = [
]
pytorch_cpu = ["torch==2.5.1", "torchvision==0.20.1"]
pytorch_gpu = ["torch==2.5.1", "torchvision==0.20.1"]
-# Note: omit the cuda suffix and installing from the appropriate wheel will result in using locally installed CUDA.
+# Note: omit the cuda suffix and installing from the appropriate wheel
+# will result in using locally installed CUDA.
###############################################################################
-# Linting Configurations #
+# Linting & Formatting Configurations #
###############################################################################
-
-# yapf configuration
-[tool.yapf]
-based_on_style = "yapf"
-each_dict_entry_on_separate_line = false
-split_all_top_level_comma_separated_values = true
-column_limit = 80
-[tool.yapfignore]
-ignore_patterns = ["algoperf/_version.py", ".eggs/**"]
-
-# isort configuration
-[tool.isort]
-profile = "google"
-
-# pylint configuration
-[tool.pylint.MASTER]
-persistent = false
-ignore = "get_references_web.py,get_references_web_single_group.py,_version.py"
-
-[tool.pylint.REPORTS]
-reports = false
-msg-template = "{msg_id}:{line:3} {obj}: {msg} [{symbol}]"
-
-[tool.pylint.MESSAGES_CONTROL]
-enable = "indexing-exception,old-raise-syntax"
-
-[tool.pylint.BASIC]
-# Required attributes for module, separated by a comma
-#required-attributes=
-# Regular expression which should only match the name
-# of functions or classes which do not require a docstring.
-no-docstring-rgx = "(__.*__|main)"
-# Min length in lines of a function that requires a docstring.
-docstring-min-length = 10
-# Regular expression which should only match correct module names. The
-# leading underscore is sanctioned for private modules by Google's style
-# guide.
-#
-# There are exceptions to the basic rule (_?[a-z][a-z0-9_]*) to cover
-# requirements of Python's module system.
-module-rgx = "^(_?[a-z][a-z0-9_]*)|__init__$"
-# Regular expression which should only match correct module level names
-const-rgx = "^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$"
-# Regular expression which should only match correct class attribute
-class-attribute-rgx = "^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$"
-# Regular expression which should only match correct class names
-class-rgx = "^_?[A-Z][a-zA-Z0-9]*$"
-# Regular expression which should only match correct function names.
-# 'camel_case' and 'snake_case' group names are used for consistency of naming
-# styles across functions and methods.
-function-rgx = "^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$"
-# Regular expression which should only match correct method names.
-# 'camel_case' and 'snake_case' group names are used for consistency of naming
-# styles across functions and methods. 'exempt' indicates a name which is
-# consistent with all naming styles.
-method-rgx = "(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|_testDatasetSize|setUpClass|test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|(?:test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$"
-# Regular expression which should only match correct instance attribute names
-attr-rgx = "^_{0,2}[a-z][a-z0-9_]*$"
-# Regular expression which should only match correct argument names
-argument-rgx = "^[a-z][a-z0-9_]*$"
-# Regular expression which should only match correct variable names
-variable-rgx = "^[a-z][a-z0-9_]*$"
-# Regular expression which should only match correct list comprehension /
-# generator expression variable names
-inlinevar-rgx = "^[a-z][a-z0-9_]*$"
-# Good variable names which should always be accepted, separated by a comma
-good-names = "main,_"
-# Bad variable names which should always be refused, separated by a comma
-bad-names = ""
-# List of builtins function names that should not be used, separated by a comma
-#bad-functions=input,apply,reduce
-# List of decorators that define properties, such as abc.abstractproperty.
-property-classes = "abc.abstractproperty"
-
-[tool.pylint.typecheck]
-# Tells whether missing members accessed in mixin class should be ignored. A
-# mixin class is detected if its name ends with "mixin" (case insensitive).
-ignore-mixin-members = true
-
-# List of decorators that create context managers from functions, such as
-# contextlib.contextmanager.
-contextmanager-decorators = [
- "contextlib.contextmanager",
- "contextlib2.contextmanager",
-]
-
-[tool.pylint.VARIABLES]
-# Tells whether we should check for unused import in __init__ files.
-init-import = false
-
-# A regular expression matching names used for dummy variables (i.e. not used).
-dummy-variables-rgx = "^\\*{0,2}(_$|unused_|dummy_)"
-
-# List of additional names supposed to be defined in builtins.
-additional-builtins = []
-
-[tool.pylint.CLASSES]
-# List of method names used to declare (i.e. assign) instance attributes.
-defining-attr-methods = ["__init__", "__new__", "setUp"]
-
-# Valid names for the first argument to a class method.
-valid-classmethod-first-arg = ["cls", "class_"]
-
-[tool.pylint.EXCEPTIONS]
-overgeneral-exceptions = [
- "builtins.StandardError",
- "builtins.Exception",
- "builtins.BaseException",
-]
-
-[tool.pylint.IMPORTS]
-# Deprecated modules which should not be used, separated by a comma
-deprecated-modules = ["regsub", "TERMIOS", "Bastion", "rexec", "sets"]
-
-[tool.pylint.FORMAT]
-# List of checkers and warnings to disable.
-disable = [
- "abstract-method",
- "access-member-before-definition",
- "arguments-differ",
- "assignment-from-no-return",
- "attribute-defined-outside-init",
- "bad-mcs-classmethod-argument",
- "bad-option-value",
- "c-extension-no-member",
- "consider-merging-isinstance",
- "consider-using-dict-comprehension",
- "consider-using-enumerate",
- "consider-using-in",
- "consider-using-set-comprehension",
- "consider-using-ternary",
- "deprecated-method",
- "design",
- "file-ignored",
- "fixme",
- "global-statement",
- "import-error",
- "inconsistent-return-statements",
- "invalid-unary-operand-type",
- "len-as-condition",
- "locally-disabled",
- "locally-enabled",
- "misplaced-comparison-constant",
- "missing-docstring",
- "multiple-imports",
- "no-else-return",
- "no-member",
- "no-name-in-module",
- "no-self-use",
- "no-value-for-parameter",
- "not-an-iterable",
- "not-context-manager",
- "pointless-except",
- "protected-access",
- "redefined-argument-from-local",
- "signature-differs",
- "similarities",
- "simplifiable-if-expression",
- "star-args",
- "super-init-not-called",
- "suppressed-message",
- "too-many-function-args",
- "trailing-comma-tuple",
- "trailing-newlines",
- "ungrouped-imports",
- "unnecessary-pass",
- "unsubscriptable-object",
- "unused-argument",
- "useless-object-inheritance",
- "useless-return",
- "useless-suppression",
- "wrong-import-order",
- "wrong-import-position",
- "unneeded-not",
- "unexpected-keyword-arg",
- "redundant-keyword-arg",
- "unspecified-encoding",
- "logging-fstring-interpolation",
- "consider-using-f-string",
- "use-dict-literal",
-]
-# Maximum number of characters on a single line.
-max-line-length = 80
-ignore-long-lines = "(?x)(^\\s*(import|from)\\s|^\\s*(\\#\\ )?(https?|ftp):\\/\\/[^\\s\\/$.?#].[^\\s]*>?$|^[a-zA-Z_][a-zA-Z0-9_]*\\s*=\\s*('[^']\\S+'|\"[^\"]\\S+\"))"
-# Maximum number of lines in a module
-max-module-lines = 99999
-# String used as indentation unit. We differ from PEP8's normal 4 spaces.
-indent-string = ' '
-single-line-if-stmt = true
-# Do not warn about multiple statements on a single line for constructs like
-# if test: stmt
-[tool.pylint.LOGGING]
-logging-modules = "logging,absl.logging"
-# Add logging modules.
-[tool.pylint.MISCELLANEOUS]
-# Maximum line length for lambdas
-#short-func-length=1
-# List of module members that should be marked as deprecated.
-# All of the string functions are listed in 4.1.4 Deprecated string functions
-# in the Python 2.4 docs.
-#deprecated-members=string.atof,string.atoi,string.atol,string.capitalize,string.expandtabs,string.find,string.rfind,string.index,string.rindex,string.count,string.lower,string.split,string.rsplit,string.splitfields,string.join,string.joinfields,string.lstrip,string.rstrip,string.strip,string.swapcase,string.translate,string.upper,string.ljust,string.rjust,string.center,string.zfill,string.replace,sys.exitfunc,sys.maxint
-# List of exceptions that do not need to be mentioned in the Raises section of
-# a docstring.
-#ignore-exceptions=AssertionError,NotImplementedError,StopIteration,TypeError
-# Number of spaces of indent required when the last token on the preceding line
-# is an open (, [, or {.
-indent-after-paren = 4
+[tool.ruff]
+line-length = 80
+indent-width = 2
+exclude = ["_version.py"]
+target-version = "py311"
+
+[tool.ruff.format]
+quote-style = "single"
+
+[tool.ruff.lint]
+select = ["I", "PL"]
+ignore = [
+ # Conflicting lint rules with Ruff's formatter
+ # (see https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules).
+ "W191",
+ "E111",
+ "E114",
+ "E117",
+ "D206",
+ "D300",
+ "Q000",
+ "Q001",
+ "Q002",
+ "Q003",
+ "COM812",
+ "COM819",
+ "ISC001",
+ "ISC002",
+ "FBT001",
+ "FBT003",
+ "TD003",
+]
\ No newline at end of file
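The new `[tool.ruff]` tables drive two separate commands: `ruff check` for linting and `ruff format` for formatting. A small helper script like the sketch below (not part of the patch; the helper name is hypothetical) runs both in check-only mode, mirroring what the CI workflow and CONTRIBUTING.md set up later in this series.

```python
import subprocess
import sys


def run_style_checks() -> int:
  """Run ruff's linter and formatter in check-only (non-modifying) mode."""
  lint = subprocess.run(['ruff', 'check'], check=False)
  fmt = subprocess.run(['ruff', 'format', '--check'], check=False)
  # Return non-zero if either step reported problems.
  return lint.returncode or fmt.returncode


if __name__ == '__main__':
  sys.exit(run_style_checks())
```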
From 383db7a0f4a463537b1493a56808628db6f238b9 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Wed, 18 Jun 2025 15:52:11 +0200
Subject: [PATCH 095/123] Replace pre-commit with ruff
---
.pre-commit-config.yaml | 20 ++++++++------------
1 file changed, 8 insertions(+), 12 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index cc8f13d25..f2d684c53 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,14 +1,10 @@
repos:
- - repo: https://github.com/google/yapf
- rev: v0.32.0
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ # Ruff version.
+ rev: v0.12.0
hooks:
- - id: yapf
- args: ["--in-place", "--parallel", "--verbose", "--recursive"]
- - repo: https://github.com/pycqa/isort
- rev: 5.10.1
- hooks:
- - id: isort
- - repo: https://github.com/pycqa/pylint
- rev: v2.16.1
- hooks:
- - id: pylint
+ # Run the linter (don't change files).
+ - id: ruff-check
+ # Run the formatter (don't change files).
+ - id: ruff-format
+ args: ["--check"]
From 2c2813628e8c8e4c9942012b045ea8188db7e941 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Wed, 18 Jun 2025 16:08:54 +0200
Subject: [PATCH 096/123] Use extend-select instead, and reduce lint rules
---
pyproject.toml | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 1532a66c7..2db40f61f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -121,7 +121,8 @@ target-version = "py311"
quote-style = "single"
[tool.ruff.lint]
-select = ["I", "PL"]
+extend-select = ["I"]
+# Could add (in the future): "E", "F", "UP", "B", "SIM", "PL"
ignore = [
# Conflicting lint rules with Ruff's formatter
# (see https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules).
From 999f7a2f280f3c91c8f29cd4652200f8d63dae24 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Wed, 18 Jun 2025 16:09:09 +0200
Subject: [PATCH 097/123] Replace linting GH actions with ruff
---
.github/workflows/linting.yml | 39 +++++++++--------------------------
1 file changed, 10 insertions(+), 29 deletions(-)
diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml
index 992496b69..4308e4201 100644
--- a/.github/workflows/linting.yml
+++ b/.github/workflows/linting.yml
@@ -3,7 +3,7 @@ name: Linting
on: [push, pull_request]
jobs:
- pylint:
+ ruff-linting:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
@@ -11,19 +11,15 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: 3.11.10
- - name: Install pylint
+ - name: Install ruff
run: |
python -m pip install --upgrade pip
- pip install pylint==2.16.1
- - name: Run pylint
+ pip install ruff==0.12.0
+ - name: Run ruff linter
run: |
- pylint algoperf
- pylint reference_algorithms
- pylint prize_qualification_baselines
- pylint submission_runner.py
- pylint tests
+ ruff check
- isort:
+ ruff-formatter:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
@@ -31,26 +27,11 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: 3.11.10
- - name: Install isort
+ - name: Install ruff
run: |
python -m pip install --upgrade pip
- pip install isort==5.12.0
- - name: Run isort
+ pip install ruff==0.12.0
+ - name: Run ruff formatter
run: |
- isort . --check --diff
+ ruff format --check
- yapf:
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v2
- - name: Set up Python 3.11.10
- uses: actions/setup-python@v2
- with:
- python-version: 3.11.10
- - name: Install yapf
- run: |
- python -m pip install --upgrade pip
- pip install yapf==0.32 toml
- - name: Run yapf
- run: |
- yapf . --diff --recursive
From 7b245d6a6692509b4e994106a1c3587a640e4165 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Wed, 18 Jun 2025 16:41:28 +0200
Subject: [PATCH 098/123] Add ruff badge
---
README.md | 39 ++++++++++++++++++---------------------
1 file changed, 18 insertions(+), 21 deletions(-)
diff --git a/README.md b/README.md
index 8e470266d..0666d21d5 100644
--- a/README.md
+++ b/README.md
@@ -14,11 +14,11 @@
Benchmark/Results Paper
-[](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/CI.yml)
-[](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/linting.yml)
-[](https://github.com/mlcommons/algorithmic-efficiency/blob/main/LICENSE.md)
-[](https://github.com/google/yapf)
-[](https://discord.gg/5FPXK7SMt6)
+[](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/CI.yml)
+[](https://github.com/mlcommons/algorithmic-efficiency/actions/workflows/linting.yml)
+[](https://github.com/astral-sh/ruff)
+[](LICENSE.md)
+[](https://discord.gg/5FPXK7SMt6)
---
@@ -28,11 +28,12 @@ Submissions are evaluated based on their "time-to-result", i.e., the wall-clock
---
-> This is the repository for the *AlgoPerf: Training Algorithms benchmark* measuring neural network training speedups due to algorithmic improvements.
+> This is the repository for the _AlgoPerf: Training Algorithms benchmark_ measuring neural network training speedups due to algorithmic improvements.
> It is developed by the [MLCommons Algorithms Working Group](https://mlcommons.org/en/groups/research-algorithms/).
> This repository holds the benchmark code, the benchmark's [**technical documentation**](/docs/DOCUMENTATION.md) and [**getting started guides**](/docs/GETTING_STARTED.md). For a detailed description of the benchmark design, see our [**introductory paper**](https://arxiv.org/abs/2306.07179), for the results of the inaugural competition see our [**results paper**](https://openreview.net/forum?id=CtM5xjRSfm).
>
> **See our [AlgoPerf Leaderboard](https://github.com/mlcommons/submissions_algorithms) for the latest results of the benchmark and to submit your algorithm.**
+
---
> [!IMPORTANT]
@@ -50,14 +51,13 @@ Submissions are evaluated based on their "time-to-result", i.e., the wall-clock
## Installation
-> [!TIP]
-> **If you have any questions about the benchmark competition or you run into any issues, please feel free to contact us.** Either [file an issue](https://github.com/mlcommons/algorithmic-efficiency/issues), ask a question on [our Discord](https://discord.gg/5FPXK7SMt6) or [join our weekly meetings](https://mlcommons.org/en/groups/research-algorithms/).
+> [!TIP] > **If you have any questions about the benchmark competition or you run into any issues, please feel free to contact us.** Either [file an issue](https://github.com/mlcommons/algorithmic-efficiency/issues), ask a question on [our Discord](https://discord.gg/5FPXK7SMt6) or [join our weekly meetings](https://mlcommons.org/en/groups/research-algorithms/).
You can install this package and dependencies in a [Python virtual environment](/docs/GETTING_STARTED.md#python-virtual-environment) or use a [Docker/Singularity/Apptainer container](/docs/GETTING_STARTED.md#docker) (recommended).
We recommend using a Docker container (or alternatively, a Singularity/Apptainer container) to ensure a similar environment to our scoring and testing environments.
Both options are described in detail in the [**Getting Started**](/docs/GETTING_STARTED.md) document.
-*TL;DR to install the Jax version for GPU run:*
+_TL;DR to install the Jax version for GPU run:_
```bash
pip3 install -e '.[pytorch_cpu]'
@@ -65,7 +65,7 @@ pip3 install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax
pip3 install -e '.[full]'
```
-*TL;DR to install the PyTorch version for GPU run:*
+_TL;DR to install the PyTorch version for GPU run:_
```bash
pip3 install -e '.[jax_cpu]'
@@ -77,7 +77,7 @@ pip3 install -e '.[full]'
For detailed instructions on developing your own algorithm in the benchmark see the [Getting Started](/docs/GETTING_STARTED.md) document.
-*TL;DR running a JAX workload:*
+_TL;DR running a JAX workload:_
```bash
python3 submission_runner.py \
@@ -89,7 +89,7 @@ python3 submission_runner.py \
--tuning_search_space=reference_algorithms/paper_baselines/adamw/tuning_search_space.json
```
-*TL;DR running a PyTorch workload:*
+_TL;DR running a PyTorch workload:_
```bash
python3 submission_runner.py \
@@ -117,17 +117,15 @@ Our [**Contributing**](/docs/CONTRIBUTING.md) document provides further MLCommon
## License
-The *AlgoPerf* codebase is licensed under the [Apache License 2.0](/LICENSE.md).
+The _AlgoPerf_ codebase is licensed under the [Apache License 2.0](/LICENSE.md).
## Paper and Citing the AlgoPerf Benchmark
-In our paper ["Benchmarking Neural Network Training Algorithms"](http://arxiv.org/abs/2306.07179) we motivate, describe, and justify the *AlgoPerf: Training Algorithms* benchmark.
+In our paper ["Benchmarking Neural Network Training Algorithms"](http://arxiv.org/abs/2306.07179) we motivate, describe, and justify the _AlgoPerf: Training Algorithms_ benchmark.
-If you are using the *AlgoPerf benchmark*, its codebase, baselines, or workloads, please consider citing our paper:
+If you are using the _AlgoPerf benchmark_, its codebase, baselines, or workloads, please consider citing our paper:
-> [Dahl, Schneider, Nado, et al.
-> **Benchmarking Neural Network Training Algorithms**
-> *arXiv 2306.07179*](http://arxiv.org/abs/2306.07179)
+> [Dahl, Schneider, Nado, et al.
+> **Benchmarking Neural Network Training Algorithms**
+> _arXiv 2306.07179_](http://arxiv.org/abs/2306.07179)
```bibtex
@Misc{Dahl2023AlgoPerf,
@@ -139,10 +137,9 @@ If you are using the *AlgoPerf benchmark*, its codebase, baselines, or workloads
}
```
-If you use the results from the first *AlgoPerf competition*, please consider citing the results paper, as well as the relevant submissions:
+If you use the results from the first _AlgoPerf competition_, please consider citing the results paper, as well as the relevant submissions:
-> [Kasimbeg, Schneider, Eschenhagen, et al.
-> **Accelerating neural network training: An analysis of the AlgoPerf competition**
+> [Kasimbeg, Schneider, Eschenhagen, et al.
+> **Accelerating neural network training: An analysis of the AlgoPerf competition**
> ICLR 2025](https://openreview.net/forum?id=CtM5xjRSfm)
```bibtex
From 277674c64cfdf671c1721c5e32feb8219bc19a2c Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Wed, 18 Jun 2025 16:54:06 +0200
Subject: [PATCH 099/123] Update style testing with ruff
---
docs/CONTRIBUTING.md | 42 +++++++++++++++---------------------------
1 file changed, 15 insertions(+), 27 deletions(-)
diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md
index 7778030dc..e9918e14c 100644
--- a/docs/CONTRIBUTING.md
+++ b/docs/CONTRIBUTING.md
@@ -179,13 +179,13 @@ docker run -t -d \
To find the container IDs of running containers
```bash
-docker ps
+docker ps
```
To see the logging output
```bash
-docker logs <container_id>
+docker logs <container_id>
```
To enter a bash session in the container
@@ -209,7 +209,7 @@ docker run -t -d \
--gpus all \
--ipc=host \
\
---keep_container_alive true
+--keep_container_alive true
```
## Submitting PRs
@@ -222,38 +222,26 @@ We run tests with GitHub Actions, configured in the [.github/workflows](.github/
### Style Testing
-We run yapf and linting tests on PRs. You can view and fix offending errors with these instructions.
-
+We run formatting and linting tests via ruff on PRs. You can view and fix offending errors with these instructions.
To run the below commands, use the versions installed via `pip install -e '.[dev]'`.
-To automatically fix formatting errors, run the following (*WARNING:* this will edit your code, so it is suggested to make a git commit first!):
+To check whether your code is **formatted** correctly, run the following:
```bash
-yapf -i -r -vv -p algoperf datasets prize_qualification_baselines reference_algorithms tests *.py
+ruff format --check
```
-To sort all import orderings, run the following:
+To automatically fix formatting errors, run `ruff format`, i.e. the same command without the `--check` flag.
+(**WARNING**: this will edit your code, so it is suggested to make a git commit first!)
-```bash
-isort .
-```
-
-To just print out all offending import orderings, run the following:
+To check whether your code is **linted** correctly, run the following:
```bash
-isort . --check --diff
+ruff check
```
-To print out all offending pylint issues, run the following:
-
-```bash
-pylint algoperf
-pylint datasets
-pylint prize_qualification_baselines
-pylint reference_algorithms
-pylint submission_runner.py
-pylint tests
-```
+To automatically fix linting errors, run `ruff check` with the additional `--fix` flag.
+(**WARNING**: this will edit your code, so it is suggested to make a git commit first!)
### Unit and Integration Tests
@@ -270,9 +258,9 @@ To run a regression test:
1. Build and upload latest Docker images from dev branch.
- ```bash
- bash ~/algorithmic-efficiency/docker/build_docker_images.sh -b dev
- ```
+ ```bash
+ bash ~/algorithmic-efficiency/docker/build_docker_images.sh -b dev
+ ```
2. Turn on the self-hosted runner.
3. Run the self-hosted runner application for the runner to accept jobs.
From 830d2c2506e2f85d737e0cb9ce4f433b7d898923 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 10:16:03 +0200
Subject: [PATCH 100/123] Format submission_runner
---
submission_runner.py | 772 ++++++++++++++++++++++++-------------------
1 file changed, 426 insertions(+), 346 deletions(-)
diff --git a/submission_runner.py b/submission_runner.py
index 221a7c21d..8dc8589cb 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -17,41 +17,37 @@
import datetime
import gc
import importlib
-from inspect import signature
import itertools
import json
import os
import struct
import time
+from inspect import signature
from types import MappingProxyType
from typing import Any, Dict, Optional, Tuple
-from absl import app
-from absl import flags
-from absl import logging
import jax
import tensorflow as tf
import torch
import torch.distributed as dist
+from absl import app, flags, logging
# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make
# it unavailable to JAX.
tf.config.set_visible_devices([], 'GPU')
-from algoperf import checkpoint_utils
-from algoperf import halton
-from algoperf import logger_utils
-from algoperf import random_utils as prng
-from algoperf import spec
-from algoperf.profiler import PassThroughProfiler
-from algoperf.profiler import Profiler
-from algoperf.pytorch_utils import pytorch_init
-from algoperf.pytorch_utils import pytorch_setup
-from algoperf.pytorch_utils import sync_ddp_time
-from algoperf.workloads import workloads
+from algoperf import checkpoint_utils, halton, logger_utils, spec # noqa: E402
+from algoperf import random_utils as prng # noqa: E402
+from algoperf.profiler import PassThroughProfiler, Profiler # noqa: E402
+from algoperf.pytorch_utils import ( # noqa: E402
+ pytorch_init,
+ pytorch_setup,
+ sync_ddp_time,
+)
+from algoperf.workloads import workloads # noqa: E402
# Environment variables
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" # Disables tensorRT, cuda warnings.
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Disables tensorRT, cuda warnings.
# disable only for deepspeech if it works fine for other workloads
os.environ['XLA_FLAGS'] = '--xla_gpu_enable_triton_gemm=false'
@@ -62,106 +58,121 @@
WORKLOADS = workloads.WORKLOADS
flags.DEFINE_string(
- 'submission_path',
- None,
- 'The relative path of the Python file containing submission functions. '
- 'NOTE: the submission dir must have an __init__.py file!')
+ 'submission_path',
+ None,
+ 'The relative path of the Python file containing submission functions. '
+ 'NOTE: the submission dir must have an __init__.py file!',
+)
flags.DEFINE_string(
- 'workload',
- None,
- help=f'The name of the workload to run.\n Choices: {list(WORKLOADS.keys())}'
+ 'workload',
+ None,
+ help=f'The name of the workload to run.\n Choices: {list(WORKLOADS.keys())}',
)
flags.DEFINE_enum(
- 'tuning_ruleset',
- 'external',
- enum_values=['external', 'self'],
- help='Which tuning ruleset to use.')
+ 'tuning_ruleset',
+ 'external',
+ enum_values=['external', 'self'],
+ help='Which tuning ruleset to use.',
+)
flags.DEFINE_string(
- 'tuning_search_space',
- None,
- 'The path to the JSON file describing the external tuning search space.')
-flags.DEFINE_integer('num_tuning_trials',
- 1,
- 'The number of external hyperparameter trials to run.')
+ 'tuning_search_space',
+ None,
+ 'The path to the JSON file describing the external tuning search space.',
+)
+flags.DEFINE_integer(
+ 'num_tuning_trials', 1, 'The number of external hyperparameter trials to run.'
+)
flags.DEFINE_string('data_dir', '~/data', 'Dataset location.')
-flags.DEFINE_string('imagenet_v2_data_dir',
- None,
- 'Dataset location for ImageNet-v2.')
-flags.DEFINE_string('librispeech_tokenizer_vocab_path',
- '',
- 'Location to librispeech tokenizer.')
+flags.DEFINE_string(
+ 'imagenet_v2_data_dir', None, 'Dataset location for ImageNet-v2.'
+)
+flags.DEFINE_string(
+ 'librispeech_tokenizer_vocab_path', '', 'Location to librispeech tokenizer.'
+)
flags.DEFINE_enum(
- 'framework',
- None,
- enum_values=['jax', 'pytorch'],
- help='Whether to use Jax or Pytorch for the submission. Controls among '
- 'other things if the Jax or Numpy RNG library is used for RNG.')
+ 'framework',
+ None,
+ enum_values=['jax', 'pytorch'],
+ help='Whether to use Jax or Pytorch for the submission. Controls among '
+ 'other things if the Jax or Numpy RNG library is used for RNG.',
+)
flags.DEFINE_boolean(
- 'torch_compile',
- True,
- 'Whether to use `torch.compile` to JIT-compile PyTorch code. '
- 'This will only take effect when `framework`==pytorch.')
+ 'torch_compile',
+ True,
+ 'Whether to use `torch.compile` to JIT-compile PyTorch code. '
+ 'This will only take effect when `framework`==pytorch.',
+)
flags.DEFINE_string(
- 'experiment_dir',
- None,
- 'The root directory to store all experiments. '
- 'It is required and the directory should have '
- 'an absolute path rather than a relative path.')
+ 'experiment_dir',
+ None,
+ 'The root directory to store all experiments. '
+ 'It is required and the directory should have '
+ 'an absolute path rather than a relative path.',
+)
flags.DEFINE_string('experiment_name', None, 'Name of the experiment.')
flags.DEFINE_boolean(
- 'save_checkpoints',
- True,
- 'Whether or not to save checkpoints of the model and optimizer '
- 'at every eval and after training.')
+ 'save_checkpoints',
+ True,
+ 'Whether or not to save checkpoints of the model and optimizer '
+ 'at every eval and after training.',
+)
+flags.DEFINE_boolean(
+ 'save_intermediate_checkpoints',
+ True,
+ 'Whether to save any intermediate checkpoints. '
+ 'If False, it will only keep the latest checkpoint.',
+)
flags.DEFINE_boolean(
- 'save_intermediate_checkpoints',
- True,
- 'Whether to save any intermediate checkpoints. '
- 'If False, it will only keep the latest checkpoint.')
-flags.DEFINE_boolean('resume_last_run',
- None,
- 'Whether to resume the experiment from its last run.')
+ 'resume_last_run', None, 'Whether to resume the experiment from its last run.'
+)
flags.DEFINE_boolean(
- 'append_timestamp',
- False,
- 'If True, the current datetime will be appended to the experiment name. '
- 'Useful for guaranteeing a unique experiment dir for new runs.')
-flags.DEFINE_boolean('use_wandb',
- False,
- 'Whether to use Weights & Biases logging.')
+ 'append_timestamp',
+ False,
+ 'If True, the current datetime will be appended to the experiment name. '
+ 'Useful for guaranteeing a unique experiment dir for new runs.',
+)
+flags.DEFINE_boolean(
+ 'use_wandb', False, 'Whether to use Weights & Biases logging.'
+)
flags.DEFINE_boolean('profile', False, 'Whether to produce profiling output.')
-flags.DEFINE_integer('max_global_steps',
- None,
- 'Maximum number of update steps.')
+flags.DEFINE_integer(
+ 'max_global_steps', None, 'Maximum number of update steps.'
+)
flags.DEFINE_boolean(
- 'overwrite',
- False,
- 'Whether to overwrite the experiment with identical experiment_dir and'
- 'experiment_name.')
+ 'overwrite',
+ False,
+ 'Whether to overwrite the experiment with identical experiment_dir and'
+ 'experiment_name.',
+)
flags.DEFINE_integer(
- 'hparam_start_index',
- None,
- 'Start index to slice set of hyperparameters in tuning search space.')
+ 'hparam_start_index',
+ None,
+ 'Start index to slice set of hyperparameters in tuning search space.',
+)
flags.DEFINE_integer(
- 'hparam_end_index',
- None,
- 'End index to slice set of hyperparameters in tuning search space.')
+ 'hparam_end_index',
+ None,
+ 'End index to slice set of hyperparameters in tuning search space.',
+)
flags.DEFINE_integer(
- 'rng_seed',
- None,
- 'Value of rng seed. If None, a random seed will'
- 'be generated from hardware.')
-flags.DEFINE_boolean('set_pytorch_max_split_size',
- False,
- 'If true, set pytorch max_split_size_mb to 256')
+ 'rng_seed',
+ None,
+ 'Value of rng seed. If None, a random seed will be generated from hardware.',
+)
+flags.DEFINE_boolean(
+ 'set_pytorch_max_split_size',
+ False,
+ 'If true, set pytorch max_split_size_mb to 256',
+)
flags.DEFINE_integer(
- 'pytorch_eval_num_workers',
- 0,
- 'Number of workers for ImageNet PyTorch evaluation data loaders.'
- 'WARNING: Setting pytorch_eval_num_workers != 0, will result '
- 'in incorrect evals currently, see issues/732.')
+ 'pytorch_eval_num_workers',
+ 0,
+ 'Number of workers for ImageNet PyTorch evaluation data loaders.'
+ 'WARNING: Setting pytorch_eval_num_workers != 0, will result '
+ 'in incorrect evals currently, see issues/732.',
+)
FLAGS = flags.FLAGS
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
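The block above is a straight reformat of the `absl.flags` declarations; functionally nothing changes. As a reminder of the pattern (with hypothetical flag names, not ones from `submission_runner.py`): flags are declared at module scope, their multi-line help strings rely on implicit concatenation of adjacent string literals (hence the trailing spaces), and values are read from the global `FLAGS` object inside `main()`.

```python
from absl import app, flags

flags.DEFINE_string(
  'experiment_note',
  None,
  'Free-form note attached to the experiment log. '
  'Adjacent literals are concatenated, so keep the trailing space.',
)
flags.DEFINE_integer(
  'num_repeats', 1, 'How many times to repeat the toy experiment.'
)

FLAGS = flags.FLAGS


def main(_):
  for _ in range(FLAGS.num_repeats):
    print(FLAGS.experiment_note or 'no note given')


if __name__ == '__main__':
  app.run(main)
```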
@@ -193,23 +204,23 @@ def _reset_cuda_mem():
def train_once(
- workload: spec.Workload,
- workload_name: str,
- global_batch_size: int,
- global_eval_batch_size: int,
- data_dir: str,
- imagenet_v2_data_dir: str,
- init_optimizer_state: spec.InitOptimizerFn,
- update_params: spec.UpdateParamsFn,
- data_selection: spec.DataSelectionFn,
- prepare_for_eval: Optional[spec.PrepareForEvalFn],
- hyperparameters: Optional[spec.Hyperparameters],
- rng_seed: int,
- rng: spec.RandomState,
- profiler: Profiler,
- max_global_steps: int = None,
- log_dir: Optional[str] = None,
- save_checkpoints: Optional[bool] = True
+ workload: spec.Workload,
+ workload_name: str,
+ global_batch_size: int,
+ global_eval_batch_size: int,
+ data_dir: str,
+ imagenet_v2_data_dir: str,
+ init_optimizer_state: spec.InitOptimizerFn,
+ update_params: spec.UpdateParamsFn,
+ data_selection: spec.DataSelectionFn,
+ prepare_for_eval: Optional[spec.PrepareForEvalFn],
+ hyperparameters: Optional[spec.Hyperparameters],
+ rng_seed: int,
+ rng: spec.RandomState,
+ profiler: Profiler,
+ max_global_steps: int = None,
+ log_dir: Optional[str] = None,
+ save_checkpoints: Optional[bool] = True,
) -> Tuple[spec.Timing, Dict[str, Any]]:
_reset_cuda_mem()
data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4)
@@ -222,40 +233,44 @@ def train_once(
workload.eval_num_workers = FLAGS.pytorch_eval_num_workers
with profiler.profile('Initializing dataset'):
input_queue = workload._build_input_queue(
- data_rng,
- 'train',
- data_dir=data_dir,
- global_batch_size=global_batch_size)
+ data_rng, 'train', data_dir=data_dir, global_batch_size=global_batch_size
+ )
logging.info('Initializing model.')
with profiler.profile('Initializing model'):
model_params, model_state = workload.init_model_fn(model_init_rng)
if FLAGS.framework == 'pytorch' and FLAGS.torch_compile:
compile_error_workloads = [
- 'librispeech_conformer',
- 'ogbg',
- 'criteo1tb',
- 'imagenet_vit',
- 'librispeech_deepspeech'
+ 'librispeech_conformer',
+ 'ogbg',
+ 'criteo1tb',
+ 'imagenet_vit',
+ 'librispeech_deepspeech',
]
eager_backend_workloads = []
aot_eager_backend_workloads = []
loss_compilation_workloads = [
- 'fastmri', 'librispeech_deepspeech', 'ogbg', 'wmt'
+ 'fastmri',
+ 'librispeech_deepspeech',
+ 'ogbg',
+ 'wmt',
]
base_workload = workloads.get_base_workload_name(workload_name)
if base_workload in compile_error_workloads:
logging.warning(
- 'These workloads cannot be fully compiled under current '
- 'PyTorch version. Proceeding without `torch.compile`.')
+ 'These workloads cannot be fully compiled under current '
+ 'PyTorch version. Proceeding without `torch.compile`.'
+ )
elif base_workload in eager_backend_workloads:
logging.warning(
- 'These workloads cannot be fully compiled under current '
- 'PyTorch version. Proceeding with `backend=eager`.')
+ 'These workloads cannot be fully compiled under current '
+ 'PyTorch version. Proceeding with `backend=eager`.'
+ )
model_params = torch.compile(model_params, backend='eager')
elif base_workload in aot_eager_backend_workloads:
logging.warning(
- 'These workloads cannot be fully compiled under current '
- 'PyTorch version. Proceeding with `backend=aot_eager`.')
+ 'These workloads cannot be fully compiled under current '
+ 'PyTorch version. Proceeding with `backend=aot_eager`.'
+ )
model_params = torch.compile(model_params, backend='aot_eager')
else:
logging.info('Performing `torch.compile`.')
@@ -264,11 +279,9 @@ def train_once(
workload.loss_fn = torch.compile(workload.loss_fn)
logging.info('Initializing optimizer.')
with profiler.profile('Initializing optimizer'):
- optimizer_state = init_optimizer_state(workload,
- model_params,
- model_state,
- hyperparameters,
- opt_init_rng)
+ optimizer_state = init_optimizer_state(
+ workload, model_params, model_state, hyperparameters, opt_init_rng
+ )
logging.info('Initializing metrics bundle.')
# Check if 'train_state' is in the function signature
@@ -276,15 +289,15 @@ def train_once(
# Bookkeeping.
train_state = {
- 'validation_goal_reached': False,
- 'test_goal_reached': False,
- 'is_time_remaining': True,
- 'last_eval_time': 0,
- 'training_complete': False,
- 'accumulated_submission_time': 0,
- 'accumulated_eval_time': 0,
- 'accumulated_logging_time': 0,
- 'last_step_end_time': None,
+ 'validation_goal_reached': False,
+ 'test_goal_reached': False,
+ 'is_time_remaining': True,
+ 'last_eval_time': 0,
+ 'training_complete': False,
+ 'accumulated_submission_time': 0,
+ 'accumulated_eval_time': 0,
+ 'accumulated_logging_time': 0,
+ 'last_step_end_time': None,
}
global_step = 0
eval_results = []
@@ -294,22 +307,25 @@ def train_once(
logging.info('Initializing checkpoint and logger.')
if log_dir is not None:
# If the checkpoint exists, load from the checkpoint.
- (optimizer_state,
- model_params,
- model_state,
- train_state,
- eval_results,
- global_step,
- preemption_count) = checkpoint_utils.maybe_restore_checkpoint(
- FLAGS.framework,
- optimizer_state,
- model_params,
- model_state,
- train_state,
- eval_results,
- global_step,
- preemption_count,
- checkpoint_dir=log_dir)
+ (
+ optimizer_state,
+ model_params,
+ model_state,
+ train_state,
+ eval_results,
+ global_step,
+ preemption_count,
+ ) = checkpoint_utils.maybe_restore_checkpoint(
+ FLAGS.framework,
+ optimizer_state,
+ model_params,
+ model_state,
+ train_state,
+ eval_results,
+ global_step,
+ preemption_count,
+ checkpoint_dir=log_dir,
+ )
meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json')
logging.info(f'Saving meta data to {meta_file_name}.')
meta_data = logger_utils.get_meta_data(workload, rng_seed)
@@ -319,9 +335,9 @@ def train_once(
logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict())
metrics_logger = None
if RANK == 0:
- metrics_logger = logger_utils.set_up_loggers(log_dir,
- flags.FLAGS,
- hyperparameters)
+ metrics_logger = logger_utils.set_up_loggers(
+ log_dir, flags.FLAGS, hyperparameters
+ )
workload.attach_metrics_logger(metrics_logger)
global_start_time = get_time()
@@ -329,42 +345,50 @@ def train_once(
logging.info('Starting training loop.')
goals_reached = (
- train_state['validation_goal_reached'] and
- train_state['test_goal_reached'])
- while train_state['is_time_remaining'] and \
- not goals_reached and \
- not train_state['training_complete']:
-
+ train_state['validation_goal_reached'] and train_state['test_goal_reached']
+ )
+ while (
+ train_state['is_time_remaining']
+ and not goals_reached
+ and not train_state['training_complete']
+ ):
step_rng = prng.fold_in(rng, global_step)
- data_select_rng, update_rng, prep_eval_rng, eval_rng = \
- prng.split(step_rng, 4)
+ data_select_rng, update_rng, prep_eval_rng, eval_rng = prng.split(
+ step_rng, 4
+ )
with profiler.profile('Data selection'):
- batch = data_selection(workload,
- input_queue,
- optimizer_state,
- model_params,
- model_state,
- hyperparameters,
- global_step,
- data_select_rng)
+ batch = data_selection(
+ workload,
+ input_queue,
+ optimizer_state,
+ model_params,
+ model_state,
+ hyperparameters,
+ global_step,
+ data_select_rng,
+ )
try:
with profiler.profile('Update parameters'):
optimizer_state, model_params, model_state = update_params(
- workload=workload,
- current_param_container=model_params,
- current_params_types=workload.model_params_types,
- model_state=model_state,
- hyperparameters=hyperparameters,
- batch=batch,
- loss_type=workload.loss_type,
- optimizer_state=optimizer_state,
- eval_results=eval_results,
- global_step=global_step,
- rng=update_rng,
- **({'train_state': MappingProxyType(train_state)}
- if needs_train_state else {}))
+ workload=workload,
+ current_param_container=model_params,
+ current_params_types=workload.model_params_types,
+ model_state=model_state,
+ hyperparameters=hyperparameters,
+ batch=batch,
+ loss_type=workload.loss_type,
+ optimizer_state=optimizer_state,
+ eval_results=eval_results,
+ global_step=global_step,
+ rng=update_rng,
+ **(
+ {'train_state': MappingProxyType(train_state)}
+ if needs_train_state
+ else {}
+ ),
+ )
except spec.TrainingCompleteError:
train_state['training_complete'] = True
global_step += 1
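The RNG handling at the top of this loop is worth calling out: the run-level key is folded with the step counter and then split into one sub-key per purpose, so every step gets fresh but reproducible randomness. A stripped-down sketch using plain `jax.random` (the runner itself goes through `algoperf.random_utils as prng`, but the call pattern is the same):

```python
import jax

rng = jax.random.key(0)  # run-level key; the seed here is an assumption

for global_step in range(3):
  # Derive a step-specific key without consuming the run-level key.
  step_rng = jax.random.fold_in(rng, global_step)
  # One sub-key per stage: data selection, update, eval prep, eval.
  data_select_rng, update_rng, prep_eval_rng, eval_rng = jax.random.split(
    step_rng, 4
  )
```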
@@ -374,121 +398,139 @@ def train_once(
train_step_end_time = get_time()
train_state['accumulated_submission_time'] += (
- train_step_end_time - train_state['last_step_end_time'])
+ train_step_end_time - train_state['last_step_end_time']
+ )
# Check if submission is eligible for an untimed eval.
- if ((train_step_end_time - train_state['last_eval_time']) >=
- workload.eval_period_time_sec or train_state['training_complete']):
-
+ if (
+ train_step_end_time - train_state['last_eval_time']
+ ) >= workload.eval_period_time_sec or train_state['training_complete']:
# Prepare for evaluation (timed).
if prepare_for_eval is not None:
-
with profiler.profile('Prepare for eval'):
del batch
prepare_for_eval_start_time = get_time()
optimizer_state, model_params, model_state = prepare_for_eval(
- workload=workload,
- current_param_container=model_params,
- current_params_types=workload.model_params_types,
- model_state=model_state,
- hyperparameters=hyperparameters,
- loss_type=workload.loss_type,
- optimizer_state=optimizer_state,
- eval_results=eval_results,
- global_step=global_step,
- rng=prep_eval_rng)
+ workload=workload,
+ current_param_container=model_params,
+ current_params_types=workload.model_params_types,
+ model_state=model_state,
+ hyperparameters=hyperparameters,
+ loss_type=workload.loss_type,
+ optimizer_state=optimizer_state,
+ eval_results=eval_results,
+ global_step=global_step,
+ rng=prep_eval_rng,
+ )
prepare_for_eval_end_time = get_time()
# Update submission time.
train_state['accumulated_submission_time'] += (
- prepare_for_eval_end_time - prepare_for_eval_start_time)
+ prepare_for_eval_end_time - prepare_for_eval_start_time
+ )
# Check if time is remaining,
# use 1.5x the runtime budget for the self-tuning ruleset.
max_allowed_runtime_sec = (
- workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external'
- else 1.5 * workload.max_allowed_runtime_sec)
+ workload.max_allowed_runtime_sec
+ if FLAGS.tuning_ruleset == 'external'
+ else 1.5 * workload.max_allowed_runtime_sec
+ )
train_state['is_time_remaining'] = (
- train_state['accumulated_submission_time'] < max_allowed_runtime_sec)
+ train_state['accumulated_submission_time'] < max_allowed_runtime_sec
+ )
# Eval if time is remaining (untimed).
if train_state['is_time_remaining']:
-
with profiler.profile('Evaluation'):
_reset_cuda_mem()
try:
eval_start_time = get_time()
- latest_eval_result = workload.eval_model(global_eval_batch_size,
- model_params,
- model_state,
- eval_rng,
- data_dir,
- imagenet_v2_data_dir,
- global_step)
+ latest_eval_result = workload.eval_model(
+ global_eval_batch_size,
+ model_params,
+ model_state,
+ eval_rng,
+ data_dir,
+ imagenet_v2_data_dir,
+ global_step,
+ )
# Check if targets reached.
# Note that this is one of the stopping conditions for the length of
# a training run. To score the run we only consider the time
# to validation target retrospectively.
train_state['validation_goal_reached'] = (
- workload.has_reached_validation_target(latest_eval_result) or
- train_state['validation_goal_reached'])
+ workload.has_reached_validation_target(latest_eval_result)
+ or train_state['validation_goal_reached']
+ )
train_state['test_goal_reached'] = (
- workload.has_reached_test_target(latest_eval_result) or
- train_state['test_goal_reached'])
+ workload.has_reached_test_target(latest_eval_result)
+ or train_state['test_goal_reached']
+ )
goals_reached = (
- train_state['validation_goal_reached'] and
- train_state['test_goal_reached'])
+ train_state['validation_goal_reached']
+ and train_state['test_goal_reached']
+ )
# Save last eval time.
eval_end_time = get_time()
train_state['last_eval_time'] = eval_end_time
# Accumulate eval time.
- train_state[
- 'accumulated_eval_time'] += eval_end_time - eval_start_time
+ train_state['accumulated_eval_time'] += (
+ eval_end_time - eval_start_time
+ )
# Add times to eval results for logging.
- latest_eval_result['score'] = (
- train_state['accumulated_submission_time'])
- latest_eval_result[
- 'total_duration'] = eval_end_time - global_start_time
+ latest_eval_result['score'] = train_state[
+ 'accumulated_submission_time'
+ ]
+ latest_eval_result['total_duration'] = (
+ eval_end_time - global_start_time
+ )
latest_eval_result['accumulated_submission_time'] = train_state[
- 'accumulated_submission_time']
+ 'accumulated_submission_time'
+ ]
latest_eval_result['accumulated_eval_time'] = train_state[
- 'accumulated_eval_time']
+ 'accumulated_eval_time'
+ ]
latest_eval_result['accumulated_logging_time'] = train_state[
- 'accumulated_logging_time']
+ 'accumulated_logging_time'
+ ]
time_since_start = latest_eval_result['total_duration']
- logging.info(f'Time since start: {time_since_start:.2f}s, '
- f'\tStep: {global_step}, \t{latest_eval_result}')
+ logging.info(
+ f'Time since start: {time_since_start:.2f}s, '
+ f'\tStep: {global_step}, \t{latest_eval_result}'
+ )
eval_results.append((global_step, latest_eval_result))
logging_start_time = get_time()
if log_dir is not None and RANK == 0:
metrics_logger.append_scalar_metrics(
- latest_eval_result,
- global_step=global_step,
- preemption_count=preemption_count,
- is_eval=True,
+ latest_eval_result,
+ global_step=global_step,
+ preemption_count=preemption_count,
+ is_eval=True,
)
if save_checkpoints:
checkpoint_utils.save_checkpoint(
- framework=FLAGS.framework,
- optimizer_state=optimizer_state,
- model_params=model_params,
- model_state=model_state,
- train_state=train_state,
- eval_results=eval_results,
- global_step=global_step,
- preemption_count=preemption_count,
- checkpoint_dir=log_dir,
- save_intermediate_checkpoints=FLAGS
- .save_intermediate_checkpoints)
+ framework=FLAGS.framework,
+ optimizer_state=optimizer_state,
+ model_params=model_params,
+ model_state=model_state,
+ train_state=train_state,
+ eval_results=eval_results,
+ global_step=global_step,
+ preemption_count=preemption_count,
+ checkpoint_dir=log_dir,
+ save_intermediate_checkpoints=FLAGS.save_intermediate_checkpoints,
+ )
logging_end_time = get_time()
train_state['accumulated_logging_time'] += (
- logging_end_time - logging_start_time)
+ logging_end_time - logging_start_time
+ )
_reset_cuda_mem()
@@ -496,8 +538,9 @@ def train_once(
logging.exception(f'Eval step {global_step} error.\n')
if 'out of memory' in str(e):
logging.warning(
- 'Error: GPU out of memory during eval during step '
- f'{global_step}, error : {str(e)}.')
+ 'Error: GPU out of memory during eval during step '
+ f'{global_step}, error : {str(e)}.'
+ )
_reset_cuda_mem()
train_state['last_step_end_time'] = get_time()
@@ -506,41 +549,45 @@ def train_once(
if log_dir is not None and RANK == 0:
metrics_logger.append_scalar_metrics(
- {'score': train_state['accumulated_submission_time']},
- global_step=global_step,
- preemption_count=preemption_count)
+ {'score': train_state['accumulated_submission_time']},
+ global_step=global_step,
+ preemption_count=preemption_count,
+ )
metrics_logger.finish()
if save_checkpoints:
checkpoint_utils.save_checkpoint(
- framework=FLAGS.framework,
- optimizer_state=optimizer_state,
- model_params=model_params,
- model_state=model_state,
- train_state=train_state,
- eval_results=eval_results,
- global_step=global_step,
- preemption_count=preemption_count,
- checkpoint_dir=log_dir,
- save_intermediate_checkpoints=FLAGS.save_intermediate_checkpoints)
+ framework=FLAGS.framework,
+ optimizer_state=optimizer_state,
+ model_params=model_params,
+ model_state=model_state,
+ train_state=train_state,
+ eval_results=eval_results,
+ global_step=global_step,
+ preemption_count=preemption_count,
+ checkpoint_dir=log_dir,
+ save_intermediate_checkpoints=FLAGS.save_intermediate_checkpoints,
+ )
return train_state['accumulated_submission_time'], metrics
-def score_submission_on_workload(workload: spec.Workload,
- workload_name: str,
- submission_path: str,
- data_dir: str,
- tuning_ruleset: str,
- profiler: Optional[Profiler] = None,
- max_global_steps: Optional[int] = None,
- imagenet_v2_data_dir: Optional[str] = None,
- tuning_search_space: Optional[str] = None,
- num_tuning_trials: Optional[int] = None,
- log_dir: Optional[str] = None,
- save_checkpoints: Optional[bool] = True,
- hparam_start_index: Optional[bool] = None,
- hparam_end_index: Optional[bool] = None,
- rng_seed: Optional[int] = None):
+def score_submission_on_workload(
+ workload: spec.Workload,
+ workload_name: str,
+ submission_path: str,
+ data_dir: str,
+ tuning_ruleset: str,
+ profiler: Optional[Profiler] = None,
+ max_global_steps: Optional[int] = None,
+ imagenet_v2_data_dir: Optional[str] = None,
+ tuning_search_space: Optional[str] = None,
+ num_tuning_trials: Optional[int] = None,
+ log_dir: Optional[str] = None,
+ save_checkpoints: Optional[bool] = True,
+ hparam_start_index: Optional[bool] = None,
+ hparam_end_index: Optional[bool] = None,
+ rng_seed: Optional[int] = None,
+):
# Expand paths because '~' may not be recognized
data_dir = os.path.expanduser(data_dir)
if imagenet_v2_data_dir:
@@ -564,18 +611,21 @@ def score_submission_on_workload(workload: spec.Workload,
n_gpus = max(N_GPUS, jax.local_device_count())
if global_batch_size % n_gpus != 0:
raise ValueError(
- f'The global batch size ({global_batch_size}) has to be divisible by '
- f'the number of GPUs ({n_gpus}).')
+ f'The global batch size ({global_batch_size}) has to be divisible by '
+ f'the number of GPUs ({n_gpus}).'
+ )
if hasattr(submission_module, 'get_eval_batch_size'):
# If the user specifies the eval batch size, use the provided one.
global_eval_batch_size = submission_module.get_eval_batch_size(
- workload_name)
+ workload_name
+ )
else:
global_eval_batch_size = workload.eval_batch_size
if global_eval_batch_size % n_gpus != 0:
raise ValueError(
- f'The global eval batch size ({global_eval_batch_size}) has to be '
- f'divisible by the number of GPUs ({n_gpus}).')
+ f'The global eval batch size ({global_eval_batch_size}) has to be '
+ f'divisible by the number of GPUs ({n_gpus}).'
+ )
if tuning_ruleset == 'external':
# If the submission runner is responsible for hyperparameter tuning, load in
@@ -583,15 +633,18 @@ def score_submission_on_workload(workload: spec.Workload,
# settings from it.
if tuning_search_space is None:
raise ValueError(
- 'Must provide a tuning search space JSON file when using external '
- 'tuning.')
+ 'Must provide a tuning search space JSON file when using external '
+ 'tuning.'
+ )
with open(tuning_search_space, 'r', encoding='UTF-8') as search_space_file:
tuning_search_space = halton.generate_search(
- json.load(search_space_file), num_tuning_trials)
+ json.load(search_space_file), num_tuning_trials
+ )
all_timings = {}
all_metrics = {}
tuning_search_space_iter = itertools.islice(
- enumerate(tuning_search_space), hparam_start_index, hparam_end_index)
+ enumerate(tuning_search_space), hparam_start_index, hparam_end_index
+ )
for hi, hyperparameters in tuning_search_space_iter:
# Generate a new seed from hardware sources of randomness for each trial.
if not rng_seed:
@@ -615,25 +668,31 @@ def score_submission_on_workload(workload: spec.Workload,
# If existing hyperparameter exists, use saved
# hyperparameters for consistency.
- hyperparameters = logger_utils.write_hparams(hyperparameters,
- tuning_dir_name)
+ hyperparameters = logger_utils.write_hparams(
+ hyperparameters, tuning_dir_name
+ )
tuning_search_space[hi] = hyperparameters
with profiler.profile('Train'):
- timing, metrics = train_once(workload, workload_name,
- global_batch_size,
- global_eval_batch_size,
- data_dir, imagenet_v2_data_dir,
- init_optimizer_state,
- update_params, data_selection,
- prepare_for_eval,
- hyperparameters,
- rng_seed,
- rng,
- profiler,
- max_global_steps,
- tuning_dir_name,
- save_checkpoints=save_checkpoints,)
+ timing, metrics = train_once(
+ workload,
+ workload_name,
+ global_batch_size,
+ global_eval_batch_size,
+ data_dir,
+ imagenet_v2_data_dir,
+ init_optimizer_state,
+ update_params,
+ data_selection,
+ prepare_for_eval,
+ hyperparameters,
+ rng_seed,
+ rng,
+ profiler,
+ max_global_steps,
+ tuning_dir_name,
+ save_checkpoints=save_checkpoints,
+ )
all_timings[hi] = timing
all_metrics[hi] = metrics
logging.info(f'Tuning trial {hi + 1}/{num_tuning_trials}')
@@ -647,7 +706,8 @@ def score_submission_on_workload(workload: spec.Workload,
else:
if tuning_search_space is not None:
raise ValueError(
- 'Cannot provide a tuning search space when using self tuning.')
+ 'Cannot provide a tuning search space when using self tuning.'
+ )
if not rng_seed:
rng_seed = struct.unpack('q', os.urandom(8))[0]
rng = prng.PRNGKey(rng_seed)
@@ -659,11 +719,24 @@ def score_submission_on_workload(workload: spec.Workload,
logger_utils.makedir(log_dir)
with profiler.profile('Train'):
score, _ = train_once(
- workload, workload_name, global_batch_size, global_eval_batch_size,
- data_dir, imagenet_v2_data_dir,
- init_optimizer_state, update_params, data_selection, prepare_for_eval,
- None, rng_seed, rng, profiler, max_global_steps, log_dir,
- save_checkpoints=save_checkpoints)
+ workload,
+ workload_name,
+ global_batch_size,
+ global_eval_batch_size,
+ data_dir,
+ imagenet_v2_data_dir,
+ init_optimizer_state,
+ update_params,
+ data_selection,
+ prepare_for_eval,
+ None,
+ rng_seed,
+ rng,
+ profiler,
+ max_global_steps,
+ log_dir,
+ save_checkpoints=save_checkpoints,
+ )
return score
@@ -687,59 +760,66 @@ def main(_):
# TODO: remove once issue resolved.
if FLAGS.pytorch_eval_num_workers != 0:
logging.warning(
- 'WARNING: Setting pytorch_eval_num_workers != 0, will result '
- 'in incorrect evals currently, see issues/732.')
+ 'WARNING: Setting pytorch_eval_num_workers != 0, will result '
+ 'in incorrect evals currently, see issues/732.'
+ )
workload_metadata = WORKLOADS[FLAGS.workload]
if base_workload in [
- 'librispeech_conformer',
- 'librispeech_deepspeech',
- 'imagenet_vit',
- 'criteo1tb'
+ 'librispeech_conformer',
+ 'librispeech_deepspeech',
+ 'imagenet_vit',
+ 'criteo1tb',
]:
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.80'
# Extend path according to framework.
workload_metadata['workload_path'] = os.path.join(
- BASE_WORKLOADS_DIR,
- workload_metadata['workload_path'] + f'_{FLAGS.framework}',
- 'workload.py')
+ BASE_WORKLOADS_DIR,
+ workload_metadata['workload_path'] + f'_{FLAGS.framework}',
+ 'workload.py',
+ )
workload_init_kwargs = {}
if FLAGS.librispeech_tokenizer_vocab_path:
workload_init_kwargs['tokenizer_vocab_path'] = (
- FLAGS.librispeech_tokenizer_vocab_path)
+ FLAGS.librispeech_tokenizer_vocab_path
+ )
workload = workloads.import_workload(
- workload_path=workload_metadata['workload_path'],
- workload_class_name=workload_metadata['workload_class_name'],
- workload_init_kwargs=workload_init_kwargs)
+ workload_path=workload_metadata['workload_path'],
+ workload_class_name=workload_metadata['workload_class_name'],
+ workload_init_kwargs=workload_init_kwargs,
+ )
experiment_name = FLAGS.experiment_name
if experiment_name and FLAGS.append_timestamp:
experiment_name += datetime.datetime.now().strftime('-%Y-%m-%d-%H-%M-%S')
- logging_dir_path = logger_utils.get_log_dir(FLAGS.experiment_dir,
- FLAGS.workload,
- FLAGS.framework,
- experiment_name,
- FLAGS.resume_last_run,
- FLAGS.overwrite)
+ logging_dir_path = logger_utils.get_log_dir(
+ FLAGS.experiment_dir,
+ FLAGS.workload,
+ FLAGS.framework,
+ experiment_name,
+ FLAGS.resume_last_run,
+ FLAGS.overwrite,
+ )
score = score_submission_on_workload(
- workload=workload,
- workload_name=FLAGS.workload,
- submission_path=FLAGS.submission_path,
- data_dir=FLAGS.data_dir,
- tuning_ruleset=FLAGS.tuning_ruleset,
- profiler=profiler,
- max_global_steps=FLAGS.max_global_steps,
- imagenet_v2_data_dir=FLAGS.imagenet_v2_data_dir,
- tuning_search_space=FLAGS.tuning_search_space,
- num_tuning_trials=FLAGS.num_tuning_trials,
- log_dir=logging_dir_path,
- save_checkpoints=FLAGS.save_checkpoints,
- hparam_start_index=FLAGS.hparam_start_index,
- hparam_end_index=FLAGS.hparam_end_index,
- rng_seed=FLAGS.rng_seed)
+ workload=workload,
+ workload_name=FLAGS.workload,
+ submission_path=FLAGS.submission_path,
+ data_dir=FLAGS.data_dir,
+ tuning_ruleset=FLAGS.tuning_ruleset,
+ profiler=profiler,
+ max_global_steps=FLAGS.max_global_steps,
+ imagenet_v2_data_dir=FLAGS.imagenet_v2_data_dir,
+ tuning_search_space=FLAGS.tuning_search_space,
+ num_tuning_trials=FLAGS.num_tuning_trials,
+ log_dir=logging_dir_path,
+ save_checkpoints=FLAGS.save_checkpoints,
+ hparam_start_index=FLAGS.hparam_start_index,
+ hparam_end_index=FLAGS.hparam_end_index,
+ rng_seed=FLAGS.rng_seed,
+ )
logging.info(f'Final {FLAGS.workload} score: {score}')
if FLAGS.profile:
From d84eddf2bcf6e4db1f7b47a13ea11affc5192f16 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 10:16:12 +0200
Subject: [PATCH 101/123] Format submissions/
---
submissions/submission_checker.py | 38 ++++----
submissions/template/submission.py | 138 +++++++++++++++--------------
2 files changed, 95 insertions(+), 81 deletions(-)
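
For readers skimming the formatting-only patches in this series: the rewritten hunks consistently move from aligned continuation lines to a hanging indent with one argument per line, a trailing comma, and single-quoted strings. A self-contained sketch of the layout change (the helper below is hypothetical and exists only to make the example runnable; the exact formatter configuration is an assumption):

# Hypothetical helper, used only to illustrate the layout change.
def get_log_dir(experiment_dir, workload, framework):
  return f'{experiment_dir}/{workload}_{framework}'

# Old style: continuation arguments aligned under the opening parenthesis.
log_dir_old = get_log_dir('/tmp/experiments',
                          'ogbg',
                          'jax')

# New style: hanging indent, one argument per line, trailing comma.
log_dir_new = get_log_dir(
  '/tmp/experiments',
  'ogbg',
  'jax',
)

assert log_dir_old == log_dir_new  # Only the layout changes, not behavior.
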
diff --git a/submissions/submission_checker.py b/submissions/submission_checker.py
index ab657c0f0..fcc4e1faf 100644
--- a/submissions/submission_checker.py
+++ b/submissions/submission_checker.py
@@ -41,7 +41,8 @@ def _check_ruleset_subdirs(submission_dir):
contents = os.listdir(submission_dir)
if not ((EXTERNAL_TUNING in contents) or (SELF_TUNING in contents)):
logging.info(
- f'CHECK FAILED: {submission_dir} does not contain ruleset subdir.')
+ f'CHECK FAILED: {submission_dir} does not contain ruleset subdir.'
+ )
return False
return True
@@ -54,7 +55,7 @@ def _check_submission_module(submission_dir):
contents = os.listdir(os.path.join(root, submission_dir))
if SUBMISSION_MODULE not in contents:
logging.info(
- f'CHECK FAILED: {parent_dir}/{submission_dir} does not contain {SUBMISSION_MODULE}'
+ f'CHECK FAILED: {parent_dir}/{submission_dir} does not contain {SUBMISSION_MODULE}'
)
return False
return True
@@ -68,7 +69,7 @@ def _check_tuning_search_space_file(submission_dir):
contents = os.listdir(os.path.join(root, submission_dir))
if TUNING_SEARCH_SPACE_FILENAME not in contents:
logging.info(
- f'CHECK FAILED: {parent_dir}/{submission_dir} does not contain {TUNING_SEARCH_SPACE_FILENAME}'
+ f'CHECK FAILED: {parent_dir}/{submission_dir} does not contain {TUNING_SEARCH_SPACE_FILENAME}'
)
return False
return True
@@ -76,18 +77,22 @@ def _check_tuning_search_space_file(submission_dir):
def run_checks(submission_dir):
"""Top-level checker function.
- Call individual checkers from this function.
- """
+ Call individual checkers from this function.
+ """
logging.info('Running repository checks.')
# Execute checks
contains_ruleset_subdirs = _check_ruleset_subdirs(submission_dir)
contains_submission_module = _check_submission_module(submission_dir)
contains_tuning_search_space_file = _check_tuning_search_space_file(
- submission_dir)
+ submission_dir
+ )
- if not (contains_ruleset_subdirs and contains_submission_module and
- contains_tuning_search_space_file):
+ if not (
+ contains_ruleset_subdirs
+ and contains_submission_module
+ and contains_tuning_search_space_file
+ ):
logging.info('TESTS FAILED.')
return False
@@ -98,16 +103,17 @@ def run_checks(submission_dir):
def get_parser():
"""Parse commandline."""
parser = argparse.ArgumentParser(
- description='Checks for submission folder for AlgoPerf',)
+ description='Checks for submission folder for AlgoPerf',
+ )
parser.add_argument(
- 'folder',
- type=str,
- help='the folder for a submission package.',
+ 'folder',
+ type=str,
+ help='the folder for a submission package.',
)
parser.add_argument(
- '--log_output',
- type=str,
- default='submission_checker.log',
+ '--log_output',
+ type=str,
+ default='submission_checker.log',
)
return parser
@@ -118,7 +124,7 @@ def main():
logging.basicConfig(filename=args.log_output, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler())
- formatter = logging.Formatter("%(levelname)s - %(message)s")
+ formatter = logging.Formatter('%(levelname)s - %(message)s')
logging.getLogger().handlers[0].setFormatter(formatter)
logging.getLogger().handlers[1].setFormatter(formatter)
diff --git a/submissions/template/submission.py b/submissions/template/submission.py
index 2269b7dbb..db6900afd 100644
--- a/submissions/template/submission.py
+++ b/submissions/template/submission.py
@@ -4,94 +4,102 @@
and https://github.com/mlcommons/algorithmic-efficiency/blob/main/DOCUMENTATION.md#disallowed-submissions
for guidelines.
"""
+
from typing import Any, Dict, Iterator, List, Optional, Tuple
from algoperf import spec
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule.
- Returns: spec.OptimizerState initialized optimizer state
- """
+ Returns: spec.OptimizerState initialized optimizer state
+ """
pass
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
+ """
+ Returns:
+ spec.OptimizerState: new optimizer state
+ spec.ParameterTypeTree: new params
+ new_model_state: new model state
"""
- Returns:
- spec.OptimizerState: new optimizer state
- spec.ParameterTypeTree: new params
- new_model_state: new model state
- """
pass
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
+ """
+ Returns:
+ new_optimizer_state
+ new_params
+ new_model_state
"""
- Returns:
- new_optimizer_state
- new_params
- new_model_state
- """
pass
def get_batch_size(workload_name):
"""
- Gets batch size for workload.
- Note that these batch sizes only apply during training and not during evals.
- Args:
- workload_name (str): Valid workload_name values are: "wmt", "ogbg",
- "criteo1tb", "fastmri", "imagenet_resnet", "imagenet_vit",
- "librispeech_deepspeech", "librispeech_conformer" or any of the
- variants.
- Returns:
- int: batch_size
- Raises:
- ValueError: If workload_name is not handled.
- """
+ Gets batch size for workload.
+ Note that these batch sizes only apply during training and not during evals.
+ Args:
+ workload_name (str): Valid workload_name values are: "wmt", "ogbg",
+ "criteo1tb", "fastmri", "imagenet_resnet", "imagenet_vit",
+ "librispeech_deepspeech", "librispeech_conformer" or any of the
+ variants.
+ Returns:
+ int: batch_size
+ Raises:
+ ValueError: If workload_name is not handled.
+ """
pass
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
- Each element of the queue is a batch of training examples and labels.
- Tip:
- If you would just like the next batch from the input queue return next(input_queue).
+ Each element of the queue is a batch of training examples and labels.
+ Tip:
+ If you would just like the next batch from the input queue return next(input_queue).
- Returns:
- batch: next batch of input data
- """
+ Returns:
+ batch: next batch of input data
+ """
pass
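
The template above only declares the submission API; every body is a stub. As a minimal sketch of how a submitter might fill in two of the stubs (the batch sizes below are illustrative placeholders, not tuned or recommended values, and the template's type annotations are omitted to keep the sketch dependency-free):

def get_batch_size(workload_name):
  # Placeholder batch sizes for illustration only.
  batch_sizes = {
    'wmt': 128,
    'ogbg': 512,
    'criteo1tb': 262_144,
    'fastmri': 32,
    'imagenet_resnet': 1024,
    'imagenet_vit': 1024,
    'librispeech_conformer': 256,
    'librispeech_deepspeech': 256,
  }
  # Variants such as 'wmt_post_ln' share the batch size of their base workload.
  for base_name, batch_size in batch_sizes.items():
    if workload_name.startswith(base_name):
      return batch_size
  raise ValueError(f'Unhandled workload: {workload_name}')


def data_selection(
  workload,
  input_queue,
  optimizer_state,
  current_param_container,
  model_state,
  hyperparameters,
  global_step,
  rng,
):
  # Simplest valid behavior, per the docstring: take the next pre-shuffled batch.
  return next(input_queue)
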
From fbbeafa65541419f8cc1879948d112d34539fb5c Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 10:17:59 +0200
Subject: [PATCH 102/123] Format scoring/
---
.../generate_held_out_workloads.py | 70 +++---
scoring/algoperf_v05/score_submissions.py | 197 ++++++++-------
scoring/compute_speedups.py | 71 +++---
scoring/performance_profile.py | 218 +++++++++-------
scoring/score_submissions.py | 190 ++++++++------
scoring/scoring_utils.py | 169 ++++++-------
scoring/test_performance_profile.py | 4 +-
scoring/test_scoring_utils.py | 4 +-
scoring/utils/package_logs.py | 12 +-
scoring/utils/run_workloads.py | 232 ++++++++++--------
scoring/utils/slurm/README.md | 41 ++--
.../utils/slurm/algoperf_slurm_cluster.yaml | 130 +++++-----
.../slurm/algoperf_slurm_pakcer_builder.yaml | 169 +++++++------
scoring/utils/slurm/config.json | 210 ++++++++--------
scoring/utils/slurm/make_job_config.py | 72 +++---
.../workload_metadata_external_tuning.json | 66 ++---
.../utils/workload_metadata_self_tuning.json | 66 ++---
17 files changed, 1030 insertions(+), 891 deletions(-)
diff --git a/scoring/algoperf_v05/generate_held_out_workloads.py b/scoring/algoperf_v05/generate_held_out_workloads.py
index 647dc3c3d..e9ebf6a53 100644
--- a/scoring/algoperf_v05/generate_held_out_workloads.py
+++ b/scoring/algoperf_v05/generate_held_out_workloads.py
@@ -2,49 +2,51 @@
import os
import struct
-from absl import app
-from absl import flags
-from absl import logging
import numpy as np
+from absl import app, flags, logging
flags.DEFINE_integer(
- 'held_out_workloads_seed',
- None,
- 'Random seed for scoring.'
- 'AlgoPerf v0.5 seed: 3438810845')
-flags.DEFINE_string('output_filename',
- 'held_out_workloads.json',
- 'Path to file to record sampled held_out workloads.')
+ 'held_out_workloads_seed',
+ None,
+ 'Random seed for scoring.AlgoPerf v0.5 seed: 3438810845',
+)
+flags.DEFINE_string(
+ 'output_filename',
+ 'held_out_workloads.json',
+ 'Path to file to record sampled held_out workloads.',
+)
FLAGS = flags.FLAGS
HELD_OUT_WORKLOADS = {
- 'librispeech': [
- 'librispeech_conformer_attention_temperature',
- 'librispeech_conformer_layernorm',
- # 'librispeech_conformer_gelu', # Removed due to bug in target setting procedure
- 'librispeech_deepspeech_no_resnet',
- 'librispeech_deepspeech_norm_and_spec_aug',
- 'librispeech_deepspeech_tanh'
- ],
- 'imagenet': [
- 'imagenet_resnet_silu',
- 'imagenet_resnet_gelu',
- 'imagenet_resnet_large_bn_init',
- 'imagenet_vit_glu',
- 'imagenet_vit_post_ln',
- 'imagenet_vit_map'
- ],
- 'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'],
- 'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'],
- 'fastmri': ['fastmri_model_size', 'fastmri_tanh', 'fastmri_layernorm'],
- 'criteo1tb': [
- 'criteo1tb_layernorm', 'criteo1tb_embed_init', 'criteo1tb_resnet'
- ]
+ 'librispeech': [
+ 'librispeech_conformer_attention_temperature',
+ 'librispeech_conformer_layernorm',
+ # 'librispeech_conformer_gelu', # Removed due to bug in target setting procedure
+ 'librispeech_deepspeech_no_resnet',
+ 'librispeech_deepspeech_norm_and_spec_aug',
+ 'librispeech_deepspeech_tanh',
+ ],
+ 'imagenet': [
+ 'imagenet_resnet_silu',
+ 'imagenet_resnet_gelu',
+ 'imagenet_resnet_large_bn_init',
+ 'imagenet_vit_glu',
+ 'imagenet_vit_post_ln',
+ 'imagenet_vit_map',
+ ],
+ 'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'],
+ 'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'],
+ 'fastmri': ['fastmri_model_size', 'fastmri_tanh', 'fastmri_layernorm'],
+ 'criteo1tb': [
+ 'criteo1tb_layernorm',
+ 'criteo1tb_embed_init',
+ 'criteo1tb_resnet',
+ ],
}
def save_held_out_workloads(held_out_workloads, filename):
- with open(filename, "w") as f:
+ with open(filename, 'w') as f:
json.dump(held_out_workloads, f)
@@ -63,7 +65,7 @@ def main(_):
sampled_index = rng.integers(len(v))
sampled_held_out_workloads.append(v[sampled_index])
- logging.info(f"Sampled held-out workloads: {sampled_held_out_workloads}")
+ logging.info(f'Sampled held-out workloads: {sampled_held_out_workloads}')
save_held_out_workloads(sampled_held_out_workloads, output_filename)
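
The sampling logic above draws one held-out variant per base workload with a seeded generator. A reduced, self-contained sketch of the same idea, using only two workload families and assuming a NumPy Generator (3438810845 is the AlgoPerf v0.5 seed quoted in the flag's help string):

import numpy as np

held_out = {
  'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'],
  'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'],
}

rng = np.random.default_rng(3438810845)
sampled_held_out_workloads = []
for variants in held_out.values():
  sampled_index = rng.integers(len(variants))  # one variant per base workload
  sampled_held_out_workloads.append(variants[sampled_index])

print(sampled_held_out_workloads)
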
diff --git a/scoring/algoperf_v05/score_submissions.py b/scoring/algoperf_v05/score_submissions.py
index 8cc06b15f..6ef931a54 100644
--- a/scoring/algoperf_v05/score_submissions.py
+++ b/scoring/algoperf_v05/score_submissions.py
@@ -16,57 +16,62 @@
import os
import pickle
-from absl import app
-from absl import flags
-from absl import logging
import numpy as np
import pandas as pd
import performance_profile
import scoring_utils
+from absl import app, flags, logging
from tabulate import tabulate
flags.DEFINE_string(
- 'submission_directory',
- None,
- 'Path to submission directory containing experiment directories.')
+ 'submission_directory',
+ None,
+ 'Path to submission directory containing experiment directories.',
+)
flags.DEFINE_string(
- 'output_dir',
- 'scoring_results',
- 'Path to save performance profile artifacts, submission_summaries and results files.'
+ 'output_dir',
+ 'scoring_results',
+ 'Path to save performance profile artifacts, submission_summaries and results files.',
+)
+flags.DEFINE_boolean(
+ 'compute_performance_profiles',
+ False,
+ 'Whether or not to compute the performance profiles.',
)
-flags.DEFINE_boolean('compute_performance_profiles',
- False,
- 'Whether or not to compute the performance profiles.')
flags.DEFINE_boolean(
- 'strict',
- False,
-    'Whether to enforce scoring criteria on variant performance and on '
- '5-trial median performance. Note that during official scoring this '
- 'flag will be set to True.')
+ 'strict',
+ False,
+    'Whether to enforce scoring criteria on variant performance and on '
+ '5-trial median performance. Note that during official scoring this '
+ 'flag will be set to True.',
+)
flags.DEFINE_boolean(
- 'self_tuning_ruleset',
- False,
- 'Whether to score on self-tuning ruleset or externally tuned ruleset')
+ 'self_tuning_ruleset',
+ False,
+ 'Whether to score on self-tuning ruleset or externally tuned ruleset',
+)
flags.DEFINE_string(
- 'save_results_to_filename',
- None,
- 'Filename to save the processed results that are fed into the performance profile functions.'
+ 'save_results_to_filename',
+ None,
+ 'Filename to save the processed results that are fed into the performance profile functions.',
)
flags.DEFINE_string(
- 'load_results_from_filename',
- None,
- 'Filename to load processed results from that are fed into performance profile functions'
+ 'load_results_from_filename',
+ None,
+ 'Filename to load processed results from that are fed into performance profile functions',
)
flags.DEFINE_string(
- 'exclude_submissions',
- '',
-    'Optional comma separated list of names of submissions to exclude from scoring.'
+ 'exclude_submissions',
+ '',
+    'Optional comma separated list of names of submissions to exclude from scoring.',
)
FLAGS = flags.FLAGS
def get_summary_df(workload, workload_df, include_test_split=False):
- validation_metric, validation_target = scoring_utils.get_workload_metrics_and_targets(workload, split='validation')
+ validation_metric, validation_target = (
+ scoring_utils.get_workload_metrics_and_targets(workload, split='validation')
+ )
is_minimized = performance_profile.check_if_minimized(validation_metric)
target_op = operator.le if is_minimized else operator.ge
@@ -79,47 +84,69 @@ def get_summary_df(workload, workload_df, include_test_split=False):
summary_df['val target metric name'] = validation_metric
summary_df['val target metric value'] = validation_target
- summary_df['val target reached'] = workload_df[validation_metric].apply(
- lambda x: target_op(x, validation_target)).apply(np.any)
+ summary_df['val target reached'] = (
+ workload_df[validation_metric]
+ .apply(lambda x: target_op(x, validation_target))
+ .apply(np.any)
+ )
summary_df['best metric value on val'] = workload_df[validation_metric].apply(
- lambda x: best_op(x))
+ lambda x: best_op(x)
+ )
workload_df['index best eval on val'] = workload_df[validation_metric].apply(
- lambda x: idx_op(x))
+ lambda x: idx_op(x)
+ )
summary_df['time to best eval on val (s)'] = workload_df.apply(
- lambda x: x['accumulated_submission_time'][x['index best eval on val']],
- axis=1)
- workload_df['val target reached'] = workload_df[validation_metric].apply(
- lambda x: target_op(x, validation_target)).apply(np.any)
+ lambda x: x['accumulated_submission_time'][x['index best eval on val']],
+ axis=1,
+ )
+ workload_df['val target reached'] = (
+ workload_df[validation_metric]
+ .apply(lambda x: target_op(x, validation_target))
+ .apply(np.any)
+ )
workload_df['index to target on val'] = workload_df.apply(
- lambda x: np.argmax(target_op(x[validation_metric], validation_target))
- if x['val target reached'] else np.nan,
- axis=1)
+ lambda x: np.argmax(target_op(x[validation_metric], validation_target))
+ if x['val target reached']
+ else np.nan,
+ axis=1,
+ )
summary_df['time to target on val (s)'] = workload_df.apply(
- lambda x: x['accumulated_submission_time'][int(x[
- 'index to target on val'])] if x['val target reached'] else np.inf,
- axis=1)
+ lambda x: x['accumulated_submission_time'][int(x['index to target on val'])]
+ if x['val target reached']
+ else np.inf,
+ axis=1,
+ )
# test metrics
if include_test_split:
- test_metric, test_target = scoring_utils.get_workload_metrics_and_targets(workload, split='test')
+ test_metric, test_target = scoring_utils.get_workload_metrics_and_targets(
+ workload, split='test'
+ )
summary_df['test target metric name'] = test_metric
summary_df['test target metric value'] = test_target
- summary_df['test target reached'] = workload_df[test_metric].apply(
- lambda x: target_op(x, test_target)).apply(np.any)
+ summary_df['test target reached'] = (
+ workload_df[test_metric]
+ .apply(lambda x: target_op(x, test_target))
+ .apply(np.any)
+ )
summary_df['best metric value on test'] = workload_df[test_metric].apply(
- lambda x: best_op(x))
+ lambda x: best_op(x)
+ )
workload_df['index best eval on test'] = workload_df[test_metric].apply(
- lambda x: idx_op(x))
+ lambda x: idx_op(x)
+ )
summary_df['time to best eval on test (s)'] = workload_df.apply(
- lambda x: x['accumulated_submission_time'][x['index best eval on test']
- ],
- axis=1)
+ lambda x: x['accumulated_submission_time'][x['index best eval on test']],
+ axis=1,
+ )
summary_df['time to target on test (s)'] = summary_df.apply(
- lambda x: x['time to best eval on test (s)']
- if x['test target reached'] else np.inf,
- axis=1)
+ lambda x: x['time to best eval on test (s)']
+ if x['test target reached']
+ else np.inf,
+ axis=1,
+ )
return summary_df
@@ -133,7 +160,8 @@ def get_submission_summary(df, include_test_split=True):
print(df)
for workload, group in df.groupby('workload'):
summary_df = get_summary_df(
- workload, group, include_test_split=include_test_split)
+ workload, group, include_test_split=include_test_split
+ )
dfs.append(summary_df)
df = pd.concat(dfs)
@@ -164,61 +192,64 @@ def main(_):
# Optionally read results to filename
if FLAGS.load_results_from_filename:
with open(
- os.path.join(FLAGS.output_dir, FLAGS.load_results_from_filename),
- 'rb') as f:
+ os.path.join(FLAGS.output_dir, FLAGS.load_results_from_filename), 'rb'
+ ) as f:
results = pickle.load(f)
else:
for team in os.listdir(FLAGS.submission_directory):
for submission in os.listdir(
- os.path.join(FLAGS.submission_directory, team)):
+ os.path.join(FLAGS.submission_directory, team)
+ ):
print(submission)
if submission in FLAGS.exclude_submissions.split(','):
continue
- experiment_path = os.path.join(FLAGS.submission_directory,
- team,
- submission)
+ experiment_path = os.path.join(
+ FLAGS.submission_directory, team, submission
+ )
df = scoring_utils.get_experiment_df(experiment_path)
results[submission] = df
summary_df = get_submission_summary(df)
with open(
- os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'),
- 'w') as fout:
+ os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'), 'w'
+ ) as fout:
summary_df.to_csv(fout)
# Optionally save results to filename
if FLAGS.save_results_to_filename:
with open(
- os.path.join(FLAGS.output_dir, FLAGS.save_results_to_filename),
- 'wb') as f:
+ os.path.join(FLAGS.output_dir, FLAGS.save_results_to_filename), 'wb'
+ ) as f:
pickle.dump(results, f)
if not FLAGS.strict:
logging.warning(
- 'You are running with strict=False. This will relax '
- 'scoring criteria on the held-out workloads, number of trials and number '
- 'of studies. Your score may not be an accurate representation '
- 'under competition scoring rules. To enforce the criteria set strict=True.'
+ 'You are running with strict=False. This will relax '
+ 'scoring criteria on the held-out workloads, number of trials and number '
+ 'of studies. Your score may not be an accurate representation '
+ 'under competition scoring rules. To enforce the criteria set strict=True.'
)
if FLAGS.compute_performance_profiles:
performance_profile_df = performance_profile.compute_performance_profiles(
- results,
- time_col='score',
- min_tau=1.0,
- max_tau=4.0,
- reference_submission_tag=None,
- num_points=100,
- scale='linear',
- verbosity=0,
- self_tuning_ruleset=FLAGS.self_tuning_ruleset,
- strict=FLAGS.strict,
- output_dir=FLAGS.output_dir,
+ results,
+ time_col='score',
+ min_tau=1.0,
+ max_tau=4.0,
+ reference_submission_tag=None,
+ num_points=100,
+ scale='linear',
+ verbosity=0,
+ self_tuning_ruleset=FLAGS.self_tuning_ruleset,
+ strict=FLAGS.strict,
+ output_dir=FLAGS.output_dir,
)
if not os.path.exists(FLAGS.output_dir):
os.mkdir(FLAGS.output_dir)
performance_profile.plot_performance_profiles(
- performance_profile_df, 'score', save_dir=FLAGS.output_dir)
+ performance_profile_df, 'score', save_dir=FLAGS.output_dir
+ )
performance_profile_str = tabulate(
- performance_profile_df.T, headers='keys', tablefmt='psql')
+ performance_profile_df.T, headers='keys', tablefmt='psql'
+ )
logging.info(f'Performance profile:\n {performance_profile_str}')
scores = compute_leaderboard_score(performance_profile_df)
scores.to_csv(os.path.join(FLAGS.output_dir, 'scores.csv'))
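
To make the summary columns above concrete, here is a small worked example of the time-to-target logic for a single trial of a minimized metric; all numbers are invented:

import operator

import numpy as np

# One entry per eval of a metric where lower is better.
validation_metric_history = [0.30, 0.22, 0.18, 0.17]
accumulated_submission_time = [600.0, 1200.0, 1800.0, 2400.0]
validation_target = 0.20
target_op = operator.le

target_reached = [target_op(x, validation_target) for x in validation_metric_history]
if np.any(target_reached):
  index_to_target = int(np.argmax(target_reached))  # index of the first True
  time_to_target = accumulated_submission_time[index_to_target]
else:
  time_to_target = np.inf

print(time_to_target)  # 1800.0: the third eval is the first at or below 0.20
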
diff --git a/scoring/compute_speedups.py b/scoring/compute_speedups.py
index d0e5bf70b..1740a6dce 100644
--- a/scoring/compute_speedups.py
+++ b/scoring/compute_speedups.py
@@ -2,39 +2,39 @@
import pickle
-from absl import app
-from absl import flags
import numpy as np
import pandas as pd
-from performance_profile import BASE_WORKLOADS
-from performance_profile import get_workloads_time_to_target
+from absl import app, flags
+from performance_profile import BASE_WORKLOADS, get_workloads_time_to_target
from scipy import stats
flags.DEFINE_string('results_txt', None, 'Path to full scoring results file.')
flags.DEFINE_string(
- 'base',
- 'prize_qualification_baseline',
- 'Base submission to compare to. Defaults to the `prize_qualification_baseline`.'
+ 'base',
+ 'prize_qualification_baseline',
+ 'Base submission to compare to. Defaults to the `prize_qualification_baseline`.',
)
flags.DEFINE_string('comparison', None, 'Submission to compute the speedup of.')
-flags.DEFINE_boolean('self_tuning_ruleset',
- False,
- 'Whether the self-tuning ruleset is being scored.')
-flags.DEFINE_boolean('save_results',
- False,
- 'Whether to save the results to disk.')
+flags.DEFINE_boolean(
+ 'self_tuning_ruleset',
+ False,
+ 'Whether the self-tuning ruleset is being scored.',
+)
+flags.DEFINE_boolean(
+ 'save_results', False, 'Whether to save the results to disk.'
+)
FLAGS = flags.FLAGS
# These are the old budgets, used in the first iteration of the competition.
MAX_BUDGETS = {
- 'criteo1tb': 7703,
- 'fastmri': 8859,
- 'imagenet_resnet': 63_008,
- 'imagenet_vit': 77_520,
- 'librispeech_conformer': 61_068,
- 'librispeech_deepspeech': 55_506,
- 'ogbg': 18_477,
- 'wmt': 48_151,
+ 'criteo1tb': 7703,
+ 'fastmri': 8859,
+ 'imagenet_resnet': 63_008,
+ 'imagenet_vit': 77_520,
+ 'librispeech_conformer': 61_068,
+ 'librispeech_deepspeech': 55_506,
+ 'ogbg': 18_477,
+ 'wmt': 48_151,
}
@@ -63,16 +63,16 @@ def compute_speedup():
# Compute median over runtimes for both training algorithms
base_results = get_workloads_time_to_target(
- results[FLAGS.base],
- FLAGS.base,
- time_col="score",
- self_tuning_ruleset=FLAGS.self_tuning_ruleset,
+ results[FLAGS.base],
+ FLAGS.base,
+ time_col='score',
+ self_tuning_ruleset=FLAGS.self_tuning_ruleset,
)
comparison_results = get_workloads_time_to_target(
- results[FLAGS.comparison],
- FLAGS.comparison,
- time_col="score",
- self_tuning_ruleset=FLAGS.self_tuning_ruleset,
+ results[FLAGS.comparison],
+ FLAGS.comparison,
+ time_col='score',
+ self_tuning_ruleset=FLAGS.self_tuning_ruleset,
)
# Merge results
@@ -85,20 +85,23 @@ def compute_speedup():
merged_results = merged_results.apply(replace_inf, axis=1)
# Compute speedup
- merged_results['speedup'] = merged_results[
- f'{FLAGS.comparison}'] / merged_results[f'{FLAGS.base}']
+ merged_results['speedup'] = (
+ merged_results[f'{FLAGS.comparison}'] / merged_results[f'{FLAGS.base}']
+ )
speedups = merged_results['speedup'].to_numpy()
mean_speedup = stats.gmean(speedups) # Geometric mean over workload speedups
print(merged_results, end='\n\n')
print(
- f"Average speedup of {FLAGS.comparison} compared to {FLAGS.base}: {mean_speedup} or roughly {(1-mean_speedup):.1%}"
+ f'Average speedup of {FLAGS.comparison} compared to {FLAGS.base}: {mean_speedup} or roughly {(1 - mean_speedup):.1%}'
)
if FLAGS.save_results:
# Optionally save results to disk
- print("Saving results to disk...")
- filename = f'{FLAGS.comparison}_vs_{FLAGS.base}_speedup_{(1-mean_speedup):.1%}.csv'
+ print('Saving results to disk...')
+ filename = (
+ f'{FLAGS.comparison}_vs_{FLAGS.base}_speedup_{(1 - mean_speedup):.1%}.csv'
+ )
merged_results.to_csv(filename)
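
As a sanity check on the speedup arithmetic above: each workload contributes the ratio of the comparison submission's time-to-target over the base submission's, and the reported speedup is the geometric mean of those ratios (values below one mean the comparison is faster). The times in this sketch are invented:

import numpy as np
import pandas as pd
from scipy import stats

times = pd.DataFrame(
  {
    'base': [7703.0, 8859.0, 18477.0],
    'comparison': [5000.0, 7000.0, 15000.0],
  },
  index=['criteo1tb', 'fastmri', 'ogbg'],
)

times['speedup'] = times['comparison'] / times['base']
mean_speedup = stats.gmean(times['speedup'].to_numpy())

print(times, end='\n\n')
print(f'Average speedup: {mean_speedup:.3f} or roughly {(1 - mean_speedup):.1%}')
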
diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py
index 05026a0c7..d79f705d1 100644
--- a/scoring/performance_profile.py
+++ b/scoring/performance_profile.py
@@ -25,21 +25,22 @@
The keys in this dictionary should match the workload identifiers used in
the dictionary of submissions.
"""
+
import itertools
import json
import operator
import os
import re
-from absl import logging
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
+from absl import logging
from tabulate import tabulate
-from algoperf.workloads.workloads import get_base_workload_name
import algoperf.workloads.workloads as workloads_registry
+from algoperf.workloads.workloads import get_base_workload_name
from scoring import scoring_utils
WORKLOADS = workloads_registry.WORKLOADS
@@ -49,7 +50,7 @@
# Open json file to read heldout workloads
# TODO: This probably shouldn't be hardcoded but passed as an argument.\
try:
- with open("held_out_workloads_algoperf_v05.json", "r") as f:
+ with open('held_out_workloads_algoperf_v05.json', 'r') as f:
HELDOUT_WORKLOADS = json.load(f)
except:
HELDOUT_WORKLOADS = None
@@ -64,22 +65,22 @@
NUM_STUDIES = 3
MIN_EVAL_METRICS = [
- 'ce_loss',
- 'error_rate',
- 'ctc_loss',
- 'wer',
- 'l1_loss',
- 'loss',
+ 'ce_loss',
+ 'error_rate',
+ 'ctc_loss',
+ 'wer',
+ 'l1_loss',
+ 'loss',
]
MAX_EVAL_METRICS = ['mean_average_precision', 'ssim', 'accuracy', 'bleu']
-#MPL params
+# MPL params
mpl.rcParams['figure.figsize'] = (16, 10) # Width, height in inches
mpl.rcParams['font.family'] = 'serif'
-mpl.rcParams['font.serif'] = [
- 'Times New Roman'
-] + mpl.rcParams['font.serif'] # Add Times New Roman as first choice
+mpl.rcParams['font.serif'] = ['Times New Roman'] + mpl.rcParams[
+ 'font.serif'
+] # Add Times New Roman as first choice
mpl.rcParams['font.size'] = 22
mpl.rcParams['savefig.dpi'] = 300 # Set resolution for saved figures
@@ -87,16 +88,17 @@
mpl.rcParams['lines.linewidth'] = 3 # Adjust line thickness if needed
mpl.rcParams['lines.markersize'] = 6 # Adjust marker size if needed
mpl.rcParams['axes.prop_cycle'] = mpl.cycler(
- color=["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728",
- "#9467bd"]) # Example color cycle (consider ColorBrewer or viridis)
+ color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']
+) # Example color cycle (consider ColorBrewer or viridis)
mpl.rcParams['axes.labelsize'] = 22 # Axis label font size
mpl.rcParams['xtick.labelsize'] = 20 # Tick label font size
mpl.rcParams['ytick.labelsize'] = 20
# Legends and Gridlines
mpl.rcParams['legend.fontsize'] = 20 # Legend font size
-mpl.rcParams[
- 'legend.loc'] = 'best' # Let matplotlib decide the best legend location
+mpl.rcParams['legend.loc'] = (
+ 'best' # Let matplotlib decide the best legend location
+)
mpl.rcParams['axes.grid'] = True # Enable grid
mpl.rcParams['grid.alpha'] = 0.4 # Gridline transparency
@@ -113,7 +115,8 @@ def generate_eval_cols(metrics):
MINIMIZE_REGISTRY = {k: True for k in generate_eval_cols(MIN_EVAL_METRICS)}
MINIMIZE_REGISTRY.update(
- {k: False for k in generate_eval_cols(MAX_EVAL_METRICS)})
+ {k: False for k in generate_eval_cols(MAX_EVAL_METRICS)}
+)
MINIMIZE_REGISTRY['train_cost'] = True
@@ -125,13 +128,15 @@ def check_if_minimized(col_name):
if col in col_name:
return MINIMIZE_REGISTRY[col]
- raise ValueError(f'Column {col_name} not found in `MINIMIZE_REGISTRY` as '
- 'either a column name or a substring of a column name.')
+ raise ValueError(
+ f'Column {col_name} not found in `MINIMIZE_REGISTRY` as '
+ 'either a column name or a substring of a column name.'
+ )
-def get_best_trial_index(workload_df,
- validation_metric,
- validation_target=None):
+def get_best_trial_index(
+ workload_df, validation_metric, validation_target=None
+):
"""Get the eval index in which a workload reaches the target metric_col.
Args:
@@ -150,7 +155,8 @@ def get_best_trial_index(workload_df,
op = operator.le if is_minimized else operator.ge
validation_target_reached = validation_series.apply(
- lambda x: op(x, validation_target))
+ lambda x: op(x, validation_target)
+ )
target_reached = pd.Series(validation_target_reached)
# Remove trials that never reach the target
@@ -166,12 +172,14 @@ def get_best_trial_index(workload_df,
return trial, index_reached[trial]
-def get_workloads_time_to_target(submission,
- submission_name,
- time_col='global_step',
- verbosity=1,
- self_tuning_ruleset=False,
- strict=False):
+def get_workloads_time_to_target(
+ submission,
+ submission_name,
+ time_col='global_step',
+ verbosity=1,
+ self_tuning_ruleset=False,
+ strict=False,
+):
"""Get times to target for each workload in a submission.
Args:
@@ -191,60 +199,72 @@ def get_workloads_time_to_target(submission,
if num_workloads != NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS:
if strict:
raise ValueError(
- f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads '
- f'but found {num_workloads} workloads for {submission_name}.')
- logging.warning(
f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads '
- f'but found {num_workloads} workloads for {submission_name}.')
+ f'but found {num_workloads} workloads for {submission_name}.'
+ )
+ logging.warning(
+ f'Expecting {NUM_BASE_WORKLOADS + NUM_VARIANT_WORKLOADS} workloads '
+ f'but found {num_workloads} workloads for {submission_name}.'
+ )
# For each workload get submission time get the submission times to target.
for workload, group in submission.groupby('workload'):
- validation_metric, validation_target = scoring_utils.get_workload_metrics_and_targets(workload)
+ validation_metric, validation_target = (
+ scoring_utils.get_workload_metrics_and_targets(workload)
+ )
# Check number of studies
time_vals_per_study = []
num_studies = len(group.groupby('study'))
if num_studies != NUM_STUDIES:
if strict:
- raise ValueError(f'Expecting {NUM_STUDIES} studies for workload '
- f'{workload} but found {num_studies} studies '
- f'for {submission_name}.')
+ raise ValueError(
+ f'Expecting {NUM_STUDIES} studies for workload '
+ f'{workload} but found {num_studies} studies '
+ f'for {submission_name}.'
+ )
else:
- logging.warning(f'Expecting {NUM_STUDIES} studies for workload '
- f'{workload} but found {num_studies} studies '
- f'for {submission_name}.')
+ logging.warning(
+ f'Expecting {NUM_STUDIES} studies for workload '
+ f'{workload} but found {num_studies} studies '
+ f'for {submission_name}.'
+ )
# For each study check trials
for study, group in group.groupby('study'):
-
# Check number of trials per study
num_trials = len(group)
if num_trials != NUM_TRIALS and not self_tuning_ruleset:
if strict:
raise ValueError(
- f'In Study {study}: Expecting {NUM_TRIALS} trials for workload '
- f'{workload} but found {num_trials} trials '
- f'for {submission_name}.')
+ f'In Study {study}: Expecting {NUM_TRIALS} trials for workload '
+ f'{workload} but found {num_trials} trials '
+ f'for {submission_name}.'
+ )
else:
logging.warning(
- f'In Study {study}: Expecting {NUM_TRIALS} trials for workload '
- f'{workload} but found {num_trials} trials '
- f'for {submission_name}.')
+ f'In Study {study}: Expecting {NUM_TRIALS} trials for workload '
+ f'{workload} but found {num_trials} trials '
+ f'for {submission_name}.'
+ )
# Get trial and time index that reaches target
trial_idx, time_idx = get_best_trial_index(
- group, validation_metric, validation_target)
+ group, validation_metric, validation_target
+ )
if time_idx > -1:
time_val = group[time_col].loc[trial_idx][time_idx]
else:
time_val = float('inf')
time_vals_per_study.append(time_val)
- workloads.append({
+ workloads.append(
+ {
'submission': submission_name,
'workload': re.sub(r'_(jax|pytorch)$', '', workload),
time_col: np.median(time_vals_per_study),
- })
+ }
+ )
df = pd.DataFrame.from_records(workloads)
df = df.pivot(index='submission', columns='workload', values=time_col)
@@ -252,7 +272,6 @@ def get_workloads_time_to_target(submission,
def variant_criteria_filter(base_workload, variant_workload):
-
def filter(x):
try:
if x[variant_workload] == np.inf:
@@ -269,17 +288,19 @@ def filter(x):
return filter
-def compute_performance_profiles(submissions,
- time_col='global_step',
- min_tau=1.0,
- max_tau=None,
- reference_submission_tag=None,
- num_points=100,
- scale='linear',
- verbosity=0,
- strict=False,
- self_tuning_ruleset=False,
- output_dir=None):
+def compute_performance_profiles(
+ submissions,
+ time_col='global_step',
+ min_tau=1.0,
+ max_tau=None,
+ reference_submission_tag=None,
+ num_points=100,
+ scale='linear',
+ verbosity=0,
+ strict=False,
+ self_tuning_ruleset=False,
+ output_dir=None,
+):
"""Compute performance profiles for a set of submission by some time column.
Args:
@@ -308,16 +329,20 @@ def compute_performance_profiles(submissions,
for submission_tag, submission in submissions.items():
logging.info(
- f'\nComputing performance profile with respect to `{time_col}` for '
- f'{submission_tag}')
+ f'\nComputing performance profile with respect to `{time_col}` for '
+ f'{submission_tag}'
+ )
# Get time to targets for each submission across studies and trials
dfs.append(
- get_workloads_time_to_target(submission,
- submission_tag,
- time_col,
- verbosity,
- self_tuning_ruleset,
- strict))
+ get_workloads_time_to_target(
+ submission,
+ submission_tag,
+ time_col,
+ verbosity,
+ self_tuning_ruleset,
+ strict,
+ )
+ )
df = pd.concat(dfs)
# Restrict to base and sampled held-out workloads
# (ignore the additional workload variants of the baseline
@@ -335,7 +360,8 @@ def compute_performance_profiles(submissions,
# If base do not have finite score set variant score to inf
base_workload = get_base_workload_name(workload)
df[workload] = df.apply(
- variant_criteria_filter(workload, base_workload), axis=1)
+ variant_criteria_filter(workload, base_workload), axis=1
+ )
# Set score to inf if not within 4x of fastest submission
best_scores = df.min(axis=0)
@@ -347,17 +373,20 @@ def compute_performance_profiles(submissions,
# If variants do not have finite score set base_workload score to inf
base_workload = get_base_workload_name(workload)
df[base_workload] = df.apply(
- variant_criteria_filter(base_workload, workload), axis=1)
+ variant_criteria_filter(base_workload, workload), axis=1
+ )
df = df[BASE_WORKLOADS]
if verbosity > 0:
logging.info('\n`{time_col}` to reach target:')
- with pd.option_context('display.max_rows',
- None,
- 'display.max_columns',
- None,
- 'display.width',
- 1000):
+ with pd.option_context(
+ 'display.max_rows',
+ None,
+ 'display.max_columns',
+ None,
+ 'display.width',
+ 1000,
+ ):
logging.info(df)
# Divide by the fastest.
@@ -368,12 +397,14 @@ def compute_performance_profiles(submissions,
if verbosity > 0:
logging.info('\n`{time_col}` to reach target normalized to best:')
- with pd.option_context('display.max_rows',
- None,
- 'display.max_columns',
- None,
- 'display.width',
- 1000):
+ with pd.option_context(
+ 'display.max_rows',
+ None,
+ 'display.max_columns',
+ None,
+ 'display.width',
+ 1000,
+ ):
logging.info(df)
# If no max_tau is supplied, choose the value of tau that would plot all non
@@ -385,7 +416,8 @@ def compute_performance_profiles(submissions,
points = np.linspace(min_tau, max_tau, num=num_points)
elif scale == 'log':
points = np.logspace(
- np.log10(min_tau), np.log10(max_tau), num=num_points, base=10.0)
+ np.log10(min_tau), np.log10(max_tau), num=num_points, base=10.0
+ )
def rho(r, tau):
return (r <= tau).sum(axis=1) / NUM_BASE_WORKLOADS
@@ -431,11 +463,9 @@ def maybe_save_df_to_csv(save_dir, df, path, **to_csv_kwargs):
df.to_csv(fout, **to_csv_kwargs)
-def plot_performance_profiles(perf_df,
- df_col,
- scale='linear',
- save_dir=None,
- figsize=(30, 10)):
+def plot_performance_profiles(
+ perf_df, df_col, scale='linear', save_dir=None, figsize=(30, 10)
+):
"""Plot performance profiles.
Args:
@@ -462,6 +492,6 @@ def plot_performance_profiles(perf_df,
fig.legend(bbox_to_anchor=(1.0, 1.0))
plt.tight_layout()
maybe_save_figure(save_dir, f'performance_profile_by_{df_col_display}')
- maybe_save_df_to_csv(save_dir,
- perf_df,
- f'performance_profile_{df_col_display}.csv')
+ maybe_save_df_to_csv(
+ save_dir, perf_df, f'performance_profile_{df_col_display}.csv'
+ )
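
A compact, self-contained illustration of the profile computation above: times are normalized by the fastest submission on each workload, and rho(tau) is the fraction of workloads a submission solves within a factor tau of that fastest time (an inf entry means the target was never reached). The submissions and times below are invented, and the workload count is reduced to three:

import numpy as np
import pandas as pd

times = pd.DataFrame(
  {
    'fastmri': [100.0, 150.0],
    'ogbg': [200.0, 120.0],
    'wmt': [np.inf, 300.0],
  },
  index=['submission_a', 'submission_b'],
)

ratios = times / times.min(axis=0)  # divide by the fastest submission per workload


def rho(r, tau, num_workloads=3):
  # Fraction of workloads solved within a factor tau of the fastest.
  return (r <= tau).sum(axis=1) / num_workloads


for tau in (1.0, 1.5, 2.0):
  print(tau, rho(ratios, tau).to_dict())
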
diff --git a/scoring/score_submissions.py b/scoring/score_submissions.py
index f07dc8cdd..b48509f02 100644
--- a/scoring/score_submissions.py
+++ b/scoring/score_submissions.py
@@ -17,57 +17,62 @@
import os
import pickle
-from absl import app
-from absl import flags
-from absl import logging
import numpy as np
import pandas as pd
import performance_profile
import scoring_utils
+from absl import app, flags, logging
from tabulate import tabulate
flags.DEFINE_string(
- 'submission_directory',
- None,
- 'Path to submission directory containing experiment directories.')
+ 'submission_directory',
+ None,
+ 'Path to submission directory containing experiment directories.',
+)
flags.DEFINE_string(
- 'output_dir',
- 'scoring_results',
- 'Path to save performance profile artifacts, submission_summaries and results files.'
+ 'output_dir',
+ 'scoring_results',
+ 'Path to save performance profile artifacts, submission_summaries and results files.',
+)
+flags.DEFINE_boolean(
+ 'compute_performance_profiles',
+ False,
+ 'Whether or not to compute the performance profiles.',
)
-flags.DEFINE_boolean('compute_performance_profiles',
- False,
- 'Whether or not to compute the performance profiles.')
flags.DEFINE_boolean(
- 'strict',
- False,
-    'Whether to enforce scoring criteria on variant performance and on '
- '5-trial median performance. Note that during official scoring this '
- 'flag will be set to True.')
+ 'strict',
+ False,
+    'Whether to enforce scoring criteria on variant performance and on '
+ '5-trial median performance. Note that during official scoring this '
+ 'flag will be set to True.',
+)
flags.DEFINE_boolean(
- 'self_tuning_ruleset',
- False,
- 'Whether to score on self-tuning ruleset or externally tuned ruleset')
+ 'self_tuning_ruleset',
+ False,
+ 'Whether to score on self-tuning ruleset or externally tuned ruleset',
+)
flags.DEFINE_string(
- 'save_results_to_filename',
- None,
- 'Filename to save the processed results that are fed into the performance profile functions.'
+ 'save_results_to_filename',
+ None,
+ 'Filename to save the processed results that are fed into the performance profile functions.',
)
flags.DEFINE_string(
- 'load_results_from_filename',
- None,
- 'Filename to load processed results from that are fed into performance profile functions'
+ 'load_results_from_filename',
+ None,
+ 'Filename to load processed results from that are fed into performance profile functions',
)
flags.DEFINE_string(
- 'exclude_submissions',
- '',
-    'Optional comma separated list of names of submissions to exclude from scoring.'
+ 'exclude_submissions',
+ '',
+    'Optional comma separated list of names of submissions to exclude from scoring.',
)
FLAGS = flags.FLAGS
def get_summary_df(workload, workload_df, include_test_split=False):
- validation_metric, validation_target = scoring_utils.get_workload_metrics_and_targets(workload, split='validation')
+ validation_metric, validation_target = (
+ scoring_utils.get_workload_metrics_and_targets(workload, split='validation')
+ )
is_minimized = performance_profile.check_if_minimized(validation_metric)
target_op = operator.le if is_minimized else operator.ge
@@ -80,47 +85,69 @@ def get_summary_df(workload, workload_df, include_test_split=False):
summary_df['val target metric name'] = validation_metric
summary_df['val target metric value'] = validation_target
- summary_df['val target reached'] = workload_df[validation_metric].apply(
- lambda x: target_op(x, validation_target)).apply(np.any)
+ summary_df['val target reached'] = (
+ workload_df[validation_metric]
+ .apply(lambda x: target_op(x, validation_target))
+ .apply(np.any)
+ )
summary_df['best metric value on val'] = workload_df[validation_metric].apply(
- lambda x: best_op(x))
+ lambda x: best_op(x)
+ )
workload_df['index best eval on val'] = workload_df[validation_metric].apply(
- lambda x: idx_op(x))
+ lambda x: idx_op(x)
+ )
summary_df['time to best eval on val (s)'] = workload_df.apply(
- lambda x: x['accumulated_submission_time'][x['index best eval on val']],
- axis=1)
- workload_df['val target reached'] = workload_df[validation_metric].apply(
- lambda x: target_op(x, validation_target)).apply(np.any)
+ lambda x: x['accumulated_submission_time'][x['index best eval on val']],
+ axis=1,
+ )
+ workload_df['val target reached'] = (
+ workload_df[validation_metric]
+ .apply(lambda x: target_op(x, validation_target))
+ .apply(np.any)
+ )
workload_df['index to target on val'] = workload_df.apply(
- lambda x: np.argmax(target_op(x[validation_metric], validation_target))
- if x['val target reached'] else np.nan,
- axis=1)
+ lambda x: np.argmax(target_op(x[validation_metric], validation_target))
+ if x['val target reached']
+ else np.nan,
+ axis=1,
+ )
summary_df['time to target on val (s)'] = workload_df.apply(
- lambda x: x['accumulated_submission_time'][int(x[
- 'index to target on val'])] if x['val target reached'] else np.inf,
- axis=1)
+ lambda x: x['accumulated_submission_time'][int(x['index to target on val'])]
+ if x['val target reached']
+ else np.inf,
+ axis=1,
+ )
# test metrics
if include_test_split:
- test_metric, test_target = scoring_utils.get_workload_metrics_and_targets(workload, split='test')
+ test_metric, test_target = scoring_utils.get_workload_metrics_and_targets(
+ workload, split='test'
+ )
summary_df['test target metric name'] = test_metric
summary_df['test target metric value'] = test_target
- summary_df['test target reached'] = workload_df[test_metric].apply(
- lambda x: target_op(x, test_target)).apply(np.any)
+ summary_df['test target reached'] = (
+ workload_df[test_metric]
+ .apply(lambda x: target_op(x, test_target))
+ .apply(np.any)
+ )
summary_df['best metric value on test'] = workload_df[test_metric].apply(
- lambda x: best_op(x))
+ lambda x: best_op(x)
+ )
workload_df['index best eval on test'] = workload_df[test_metric].apply(
- lambda x: idx_op(x))
+ lambda x: idx_op(x)
+ )
summary_df['time to best eval on test (s)'] = workload_df.apply(
- lambda x: x['accumulated_submission_time'][x['index best eval on test']
- ],
- axis=1)
+ lambda x: x['accumulated_submission_time'][x['index best eval on test']],
+ axis=1,
+ )
summary_df['time to target on test (s)'] = summary_df.apply(
- lambda x: x['time to best eval on test (s)']
- if x['test target reached'] else np.inf,
- axis=1)
+ lambda x: x['time to best eval on test (s)']
+ if x['test target reached']
+ else np.inf,
+ axis=1,
+ )
return summary_df
@@ -134,7 +161,8 @@ def get_submission_summary(df, include_test_split=True):
print(df)
for workload, group in df.groupby('workload'):
summary_df = get_summary_df(
- workload, group, include_test_split=include_test_split)
+ workload, group, include_test_split=include_test_split
+ )
dfs.append(summary_df)
df = pd.concat(dfs)
@@ -161,13 +189,13 @@ def compute_leaderboard_score(df, normalize=True):
def main(_):
results = {}
os.makedirs(FLAGS.output_dir, exist_ok=True)
- logging.info(f"Scoring submissions in {FLAGS.submission_directory}")
+ logging.info(f'Scoring submissions in {FLAGS.submission_directory}')
# Optionally read results to filename
if FLAGS.load_results_from_filename:
with open(
- os.path.join(FLAGS.output_dir, FLAGS.load_results_from_filename),
- 'rb') as f:
+ os.path.join(FLAGS.output_dir, FLAGS.load_results_from_filename), 'rb'
+ ) as f:
results = pickle.load(f)
else:
for submission in os.listdir(FLAGS.submission_directory):
@@ -179,44 +207,46 @@ def main(_):
results[submission] = df
summary_df = get_submission_summary(df)
with open(
- os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'),
- 'w') as fout:
+ os.path.join(FLAGS.output_dir, f'{submission}_summary.csv'), 'w'
+ ) as fout:
summary_df.to_csv(fout)
# Optionally save results to filename
if FLAGS.save_results_to_filename:
with open(
- os.path.join(FLAGS.output_dir, FLAGS.save_results_to_filename),
- 'wb') as f:
+ os.path.join(FLAGS.output_dir, FLAGS.save_results_to_filename), 'wb'
+ ) as f:
pickle.dump(results, f)
if not FLAGS.strict:
logging.warning(
- 'You are running with strict=False. This will relax '
- 'scoring criteria on the held-out workloads, number of trials and number '
- 'of studies. Your score may not be an accurate representation '
- 'under competition scoring rules. To enforce the criteria set strict=True.'
+ 'You are running with strict=False. This will relax '
+ 'scoring criteria on the held-out workloads, number of trials and number '
+ 'of studies. Your score may not be an accurate representation '
+ 'under competition scoring rules. To enforce the criteria set strict=True.'
)
if FLAGS.compute_performance_profiles:
performance_profile_df = performance_profile.compute_performance_profiles(
- results,
- time_col='score',
- min_tau=1.0,
- max_tau=4.0,
- reference_submission_tag=None,
- num_points=100,
- scale='linear',
- verbosity=0,
- self_tuning_ruleset=FLAGS.self_tuning_ruleset,
- strict=FLAGS.strict,
- output_dir=FLAGS.output_dir,
+ results,
+ time_col='score',
+ min_tau=1.0,
+ max_tau=4.0,
+ reference_submission_tag=None,
+ num_points=100,
+ scale='linear',
+ verbosity=0,
+ self_tuning_ruleset=FLAGS.self_tuning_ruleset,
+ strict=FLAGS.strict,
+ output_dir=FLAGS.output_dir,
)
if not os.path.exists(FLAGS.output_dir):
os.mkdir(FLAGS.output_dir)
performance_profile.plot_performance_profiles(
- performance_profile_df, 'score', save_dir=FLAGS.output_dir)
+ performance_profile_df, 'score', save_dir=FLAGS.output_dir
+ )
performance_profile_str = tabulate(
- performance_profile_df.T, headers='keys', tablefmt='psql')
+ performance_profile_df.T, headers='keys', tablefmt='psql'
+ )
logging.info(f'Performance profile:\n {performance_profile_str}')
scores = compute_leaderboard_score(performance_profile_df)
scores.to_csv(os.path.join(FLAGS.output_dir, 'scores.csv'))
diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py
index ac513816e..ab639f870 100644
--- a/scoring/scoring_utils.py
+++ b/scoring/scoring_utils.py
@@ -4,8 +4,8 @@
import os
import re
-from absl import logging
import pandas as pd
+from absl import logging
import algoperf.workloads.workloads as workloads_registry
@@ -13,7 +13,7 @@
METRICS_LINE_REGEX = '(.*) Metrics: ({.*})'
TRIAL_DIR_REGEX = 'trial_(\d+)'
MEASUREMENTS_FILENAME = 'eval_measurements.csv'
-TIMESTAMP = r"-\d{4}(-\d{2}){5}"
+TIMESTAMP = r'-\d{4}(-\d{2}){5}'
WORKLOADS = workloads_registry.WORKLOADS
WORKLOAD_NAME_PATTERN = '(.*)(_jax|_pytorch)'
@@ -22,12 +22,11 @@
#### File IO helper functions ###
def get_logfile_paths(logdir):
- """Gets all files ending in .log in logdir
- """
+ """Gets all files ending in .log in logdir"""
filenames = os.listdir(logdir)
logfile_paths = []
for f in filenames:
- if f.endswith(".log"):
+ if f.endswith('.log'):
f = os.path.join(logdir, f)
logfile_paths.append(f)
return logfile_paths
@@ -36,23 +35,23 @@ def get_logfile_paths(logdir):
### Logfile reading helper functions ###
def decode_metrics_line(line):
"""Convert metrics line to dict.
- Args:
- line: str
-
- Returns:
- dict_of_lists: dict where keys are metric names and vals
- are lists of values.
- e.g. {'loss':[5.1, 3.2, 1.0],
- 'step':[100, 200, 300]}
- """
+ Args:
+ line: str
+
+ Returns:
+ dict_of_lists: dict where keys are metric names and vals
+ are lists of values.
+ e.g. {'loss':[5.1, 3.2, 1.0],
+ 'step':[100, 200, 300]}
+ """
eval_results = []
dict_str = re.match(METRICS_LINE_REGEX, line).group(2)
- dict_str = dict_str.replace("'", "\"")
- dict_str = dict_str.replace("(", "")
- dict_str = dict_str.replace(")", "")
- dict_str = dict_str.replace("DeviceArray", "")
- dict_str = dict_str.replace(", dtype=float32", "")
- dict_str = dict_str.replace("nan", "0")
+ dict_str = dict_str.replace("'", '"')
+ dict_str = dict_str.replace('(', '')
+ dict_str = dict_str.replace(')', '')
+ dict_str = dict_str.replace('DeviceArray', '')
+ dict_str = dict_str.replace(', dtype=float32', '')
+ dict_str = dict_str.replace('nan', '0')
metrics_dict = json.loads(dict_str)
for item in metrics_dict['eval_results']:
if isinstance(item, dict):
@@ -73,18 +72,18 @@ def decode_metrics_line(line):
def get_trials_dict(logfile):
- """Get a dict of dicts with metrics for each
- tuning run.
-
- Returns:
- trials_dict: Dict of dicts where outer dict keys
- are trial indices and inner dict key-value pairs
- are metrics and list of values.
- e.g. {'trial_0': {'loss':[5.1, 3.2, 1.0],
- 'step':[100, 200, 300]},
- 'trial_1': {'loss':[5.1, 3.2, 1.0],
- 'step':[100, 200, 300]}}
- """
+ """Get a dict of dicts with metrics for each
+ tuning run.
+
+ Returns:
+ trials_dict: Dict of dicts where outer dict keys
+ are trial indices and inner dict key-value pairs
+ are metrics and list of values.
+ e.g. {'trial_0': {'loss':[5.1, 3.2, 1.0],
+ 'step':[100, 200, 300]},
+ 'trial_1': {'loss':[5.1, 3.2, 1.0],
+ 'step':[100, 200, 300]}}
+ """
trial = 0
metrics_lines = {}
with open(logfile, 'r') as f:
@@ -100,16 +99,16 @@ def get_trials_dict(logfile):
### Results formatting helper functions ###
def get_trials_df_dict(logfile):
- """Get a dict with dataframes with metrics for each
- tuning run.
- Preferable format for saving dataframes for tables.
- Args:
- logfile: str path to logfile.
-
- Returns:
- DataFrame where indices are index of eval and
- columns are metric names.
- """
+ """Get a dict with dataframes with metrics for each
+ tuning run.
+ Preferable format for saving dataframes for tables.
+ Args:
+ logfile: str path to logfile.
+
+ Returns:
+ DataFrame where indices are index of eval and
+ columns are metric names.
+ """
trials_dict = get_trials_dict(logfile)
trials_df_dict = {}
for trial, metrics in trials_dict.items():
@@ -119,20 +118,20 @@ def get_trials_df_dict(logfile):
def get_trials_df(logfile):
"""Gets a df of per trial results from a logfile.
- Args:
- experiment_dir: str
-
- Returns:
- df: DataFrame where indices are trials, columns are
- metric names and values are lists.
- e.g
- +---------+-----------------+-----------------+
- | | loss | step |
- |---------+-----------------+-----------------|
- | trial_0 | [5.1, 3.2, 1.0] | [100, 200, 300] |
- | trial_1 | [5.1, 3.2, 1.0] | [100, 200, 300] |
- +---------+-----------------+-----------------+
- """
+ Args:
+ experiment_dir: str
+
+ Returns:
+ df: DataFrame where indices are trials, columns are
+ metric names and values are lists.
+ e.g
+ +---------+-----------------+-----------------+
+ | | loss | step |
+ |---------+-----------------+-----------------|
+ | trial_0 | [5.1, 3.2, 1.0] | [100, 200, 300] |
+ | trial_1 | [5.1, 3.2, 1.0] | [100, 200, 300] |
+ +---------+-----------------+-----------------+
+ """
trials_dict = get_trials_dict(logfile)
df = pd.DataFrame(trials_dict).transpose()
return df
@@ -141,13 +140,13 @@ def get_trials_df(logfile):
## Get scoring code
def get_experiment_df(experiment_dir):
"""Gets a df of per trial results from an experiment dir.
- The output df can be provided as input to
- performance_profile.compute_performance_profiles.
+ The output df can be provided as input to
+ performance_profile.compute_performance_profiles.
Args:
- experiment_dir: path to experiment directory containing
- results for workloads. Measurements from experiments
- sharing the same prefix but different timestamps are
- collected together.
+ experiment_dir: path to experiment directory containing
+ results for workloads. Measurements from experiments
+ sharing the same prefix but different timestamps are
+ collected together.
The directory structure is assumed to be:
+ experiment_dir
+ study
@@ -156,9 +155,9 @@ def get_experiment_df(experiment_dir):
- eval_measurements.csv
Returns:
- df: DataFrame where indices are trials, columns are
+ df: DataFrame where indices are trials, columns are
metric names and values are lists of length num evals.
- e.g
+ e.g
+----+-----------+--------+----------------------------+--------------------+--------------------+
| | workload | study |trial | validation/accuracy| score |
|----+-----------+--------+----------------------------+--------------------+--------------------|
@@ -167,35 +166,37 @@ def get_experiment_df(experiment_dir):
"""
df = pd.DataFrame()
paths = filter(
- lambda x: re.match(experiment_dir + TIMESTAMP, x) or x == experiment_dir,
- glob.glob(f"{experiment_dir}*"))
+ lambda x: re.match(experiment_dir + TIMESTAMP, x) or x == experiment_dir,
+ glob.glob(f'{experiment_dir}*'),
+ )
for experiment_dir in paths:
study_dirs = os.listdir(experiment_dir)
for study_dir in study_dirs:
workload_dirs = os.listdir(os.path.join(experiment_dir, study_dir))
workload_dirs = [
- w for w in workload_dirs
- if os.path.isdir(os.path.join(experiment_dir, study_dir, w))
+ w
+ for w in workload_dirs
+ if os.path.isdir(os.path.join(experiment_dir, study_dir, w))
]
print(workload_dirs)
for workload in workload_dirs:
data = {
- 'workload': workload,
+ 'workload': workload,
}
logging.info(os.path.join(experiment_dir, study_dir, workload))
trial_dirs = [
- t for t in os.listdir(
- os.path.join(experiment_dir, study_dir, workload))
- if re.match(TRIAL_DIR_REGEX, t)
+ t
+ for t in os.listdir(os.path.join(experiment_dir, study_dir, workload))
+ if re.match(TRIAL_DIR_REGEX, t)
]
for trial in trial_dirs:
eval_measurements_filepath = os.path.join(
- experiment_dir,
- study_dir,
- workload,
- trial,
- MEASUREMENTS_FILENAME,
+ experiment_dir,
+ study_dir,
+ workload,
+ trial,
+ MEASUREMENTS_FILENAME,
)
try:
trial_df = pd.read_csv(eval_measurements_filepath)
@@ -221,14 +222,16 @@ def get_workload_metrics_and_targets(workload, split='validation'):
# Extend path according to framework.
workload_metadata['workload_path'] = os.path.join(
- BASE_WORKLOADS_DIR,
- workload_metadata['workload_path'] + f'{framework}',
- 'workload.py')
+ BASE_WORKLOADS_DIR,
+ workload_metadata['workload_path'] + f'{framework}',
+ 'workload.py',
+ )
workload_init_kwargs = {}
workload_obj = workloads_registry.import_workload(
- workload_path=workload_metadata['workload_path'],
- workload_class_name=workload_metadata['workload_class_name'],
- workload_init_kwargs=workload_init_kwargs)
+ workload_path=workload_metadata['workload_path'],
+ workload_class_name=workload_metadata['workload_class_name'],
+ workload_init_kwargs=workload_init_kwargs,
+ )
metric_name = workload_obj.target_metric_name
if split == 'validation':
metric = f'validation/{metric_name}'
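
As reformatted above, `get_experiment_df` walks `experiment_dir/<study>/<workload>/<trial>/eval_measurements.csv` and returns one row per trial with list-valued metric columns. A minimal sketch of consuming that output, assuming the column names from the docstring example (the `get_experiment_df` call is commented out because it needs a real experiment directory; the values below are illustrative):

```python
import pandas as pd

# df = get_experiment_df('/path/to/experiments/my_experiment')  # hypothetical path

# For illustration, build a tiny frame with the documented columns:
df = pd.DataFrame({
  'workload': ['fastmri', 'fastmri'],
  'study': ['study_0', 'study_0'],
  'trial': ['trial_1', 'trial_2'],
  'validation/accuracy': [[0.1, 0.5], [0.2, 0.6]],
  'score': [[100.0, 200.0], [110.0, 210.0]],
})

# Count distinct trials per workload and pull the last recorded metric value.
trials_per_workload = df.groupby('workload')['trial'].nunique()
df['final_validation_accuracy'] = df['validation/accuracy'].map(lambda xs: xs[-1])
print(trials_per_workload)
print(df[['workload', 'trial', 'final_validation_accuracy']])
```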
diff --git a/scoring/test_performance_profile.py b/scoring/test_performance_profile.py
index 166c82d09..01c96de71 100644
--- a/scoring/test_performance_profile.py
+++ b/scoring/test_performance_profile.py
@@ -2,12 +2,10 @@
from absl.testing import absltest
-from scoring import performance_profile
-from scoring import scoring_utils
+from scoring import performance_profile, scoring_utils
class Test(absltest.TestCase):
-
def test_get_workloads_time_to_target(self):
# TODO(kasimbeg)
pass
diff --git a/scoring/test_scoring_utils.py b/scoring/test_scoring_utils.py
index 7509e3e46..e3a5f7263 100644
--- a/scoring/test_scoring_utils.py
+++ b/scoring/test_scoring_utils.py
@@ -2,8 +2,7 @@
from absl.testing import absltest
-from scoring import performance_profile
-from scoring import scoring_utils
+from scoring import performance_profile, scoring_utils
TEST_LOGFILE = 'test_data/adamw_fastmri_jax_04-18-2023-13-10-58.log'
TEST_DIR = 'test_data/experiment_dir'
@@ -11,7 +10,6 @@
class Test(absltest.TestCase):
-
def test_get_trials_dict(self):
trials_dict = scoring_utils.get_trials_dict(TEST_LOGFILE)
self.assertEqual(len(trials_dict['1']['global_step']), NUM_EVALS)
diff --git a/scoring/utils/package_logs.py b/scoring/utils/package_logs.py
index 074075abf..e341570a1 100644
--- a/scoring/utils/package_logs.py
+++ b/scoring/utils/package_logs.py
@@ -3,11 +3,11 @@
python3 package_logs.py --experiment_dir --destination_dir
"""
+
import os
import shutil
-from absl import app
-from absl import flags
+from absl import app, flags
flags.DEFINE_string('experiment_dir', None, 'Path to experiment.')
flags.DEFINE_string('destination_dir', None, 'Path to save submission logs')
@@ -17,10 +17,10 @@
def move_logs(experiment_dir, destination_dir):
"""Copy files from experiment path to destination directory.
- Args:
- experiment_dir: Path to experiment dir.
- destination_dir: Path to destination dir.
- """
+ Args:
+ experiment_dir: Path to experiment dir.
+ destination_dir: Path to destination dir.
+ """
if not os.path.exists(experiment_dir):
    raise IOError(f'Directory does not exist {experiment_dir}')
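
The body of `move_logs` beyond this existence check is not shown in the hunk; a minimal sketch of what the copy step might look like, assuming a straightforward recursive copy (the function name and paths here are illustrative, not the packaged implementation):

```python
import os
import shutil


def copy_experiment_logs(experiment_dir: str, destination_dir: str) -> None:
  """Recursively copy everything under experiment_dir into destination_dir."""
  if not os.path.exists(experiment_dir):
    raise IOError(f'Directory does not exist {experiment_dir}')
  os.makedirs(destination_dir, exist_ok=True)
  # dirs_exist_ok requires Python >= 3.8.
  shutil.copytree(experiment_dir, destination_dir, dirs_exist_ok=True)


# copy_experiment_logs('experiments/my_experiment', 'submission_logs')  # hypothetical paths
```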
diff --git a/scoring/utils/run_workloads.py b/scoring/utils/run_workloads.py
index 683fb3c63..e2de01130 100644
--- a/scoring/utils/run_workloads.py
+++ b/scoring/utils/run_workloads.py
@@ -16,101 +16,111 @@
import subprocess
import time
-from absl import app
-from absl import flags
-from absl import logging
+from absl import app, flags, logging
+import docker
from algoperf import random_utils as prng
from algoperf.workloads.workloads import get_base_workload_name
-import docker
flags.DEFINE_string(
- 'docker_image_url',
- 'europe-west4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo/algoperf_jax_dev',
- 'URL to docker image')
+ 'docker_image_url',
+ 'europe-west4-docker.pkg.dev/mlcommons-algoperf/algoperf-docker-repo/algoperf_jax_dev',
+ 'URL to docker image',
+)
flags.DEFINE_integer(
- 'run_percentage',
- 100,
- 'Percentage of max num steps to run for.'
- 'Must set the flag enable_step_budget to True for this to take effect.')
-flags.DEFINE_string('experiment_name',
- 'my_experiment',
- 'Name of top sub directory in experiment dir.')
-flags.DEFINE_boolean('rsync_data',
- True,
- 'Whether or not to transfer the data from GCP w rsync.')
+ 'run_percentage',
+ 100,
+  'Percentage of max num steps to run for. '
+ 'Must set the flag enable_step_budget to True for this to take effect.',
+)
+flags.DEFINE_string(
+ 'experiment_name',
+ 'my_experiment',
+ 'Name of top sub directory in experiment dir.',
+)
+flags.DEFINE_boolean(
+  'rsync_data', True, 'Whether or not to transfer the data from GCP with rsync.'
+)
flags.DEFINE_boolean('local', False, 'Mount local algorithmic-efficiency repo.')
flags.DEFINE_string(
- 'submission_path',
- 'prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py',
- 'Path to reference submission.')
+ 'submission_path',
+ 'prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py',
+ 'Path to reference submission.',
+)
flags.DEFINE_string(
- 'tuning_search_space',
- 'prize_qualification_baselines/external_tuning/tuning_search_space.json',
- 'Path to tuning search space.')
+ 'tuning_search_space',
+ 'prize_qualification_baselines/external_tuning/tuning_search_space.json',
+ 'Path to tuning search space.',
+)
flags.DEFINE_string('framework', 'jax', 'Can be either PyTorch or JAX.')
flags.DEFINE_boolean(
- 'dry_run',
- False,
- 'Whether or not to actually run the docker containers. '
- 'If False, simply print the docker run commands. ')
+ 'dry_run',
+ False,
+ 'Whether or not to actually run the docker containers. '
+ 'If False, simply print the docker run commands. ',
+)
flags.DEFINE_enum(
- 'tuning_ruleset',
- 'external',
- enum_values=['external', 'self'],
- help='Can be either external of self.')
+ 'tuning_ruleset',
+ 'external',
+ enum_values=['external', 'self'],
+  help='Can be either external or self.',
+)
flags.DEFINE_integer('num_studies', 5, 'Number of studies to run')
flags.DEFINE_integer('study_start_index', None, 'Start index for studies.')
flags.DEFINE_integer('study_end_index', None, 'End index for studies.')
flags.DEFINE_integer('num_tuning_trials', 5, 'Number of tuning trials.')
-flags.DEFINE_integer('hparam_start_index',
- None,
- 'Start index for tuning trials.')
+flags.DEFINE_integer(
+ 'hparam_start_index', None, 'Start index for tuning trials.'
+)
flags.DEFINE_integer('hparam_end_index', None, 'End index for tuning trials.')
flags.DEFINE_integer('seed', None, 'Random seed for evaluating a submission.')
-flags.DEFINE_integer('submission_id',
- 0,
- 'Submission ID to generate study and hparam seeds.')
-flags.DEFINE_string('held_out_workloads_config_path',
- None,
- 'Path to config containing held-out workloads')
+flags.DEFINE_integer(
+ 'submission_id', 0, 'Submission ID to generate study and hparam seeds.'
+)
flags.DEFINE_string(
- 'workload_metadata_path',
- None,
- 'Path to config containing dataset and maximum number of steps per workload.'
- 'The default values of these are set to the full budgets as determined '
- 'via the target-setting procedure. '
- 'We provide workload_metadata_external_tuning.json and '
- 'workload_metadata_self_tuning.json as references.'
- 'Note that training will be interrupted at either the set maximum number '
- 'of steps or the fixed workload maximum run time, whichever comes first. '
- 'If your algorithm has a smaller per step time than our baselines '
- 'you may want to increase the number of steps per workload.')
+ 'held_out_workloads_config_path',
+ None,
+ 'Path to config containing held-out workloads',
+)
flags.DEFINE_string(
- 'workloads',
- None,
- 'String representing a comma separated list of workload names.'
- 'If not None, only run this workload, else run all workloads in workload_metadata_path.'
+ 'workload_metadata_path',
+ None,
+  'Path to config containing dataset and maximum number of steps per workload. '
+ 'The default values of these are set to the full budgets as determined '
+ 'via the target-setting procedure. '
+ 'We provide workload_metadata_external_tuning.json and '
+  'workload_metadata_self_tuning.json as references. '
+ 'Note that training will be interrupted at either the set maximum number '
+ 'of steps or the fixed workload maximum run time, whichever comes first. '
+ 'If your algorithm has a smaller per step time than our baselines '
+ 'you may want to increase the number of steps per workload.',
+)
+flags.DEFINE_string(
+ 'workloads',
+ None,
+  'String representing a comma-separated list of workload names. '
+ 'If not None, only run this workload, else run all workloads in workload_metadata_path.',
+)
+flags.DEFINE_string(
+ 'additional_requirements_path', None, 'Path to requirements.txt if any.'
)
-flags.DEFINE_string('additional_requirements_path',
- None,
- 'Path to requirements.txt if any.')
flags.DEFINE_integer(
- 'max_steps',
- None,
- 'Maximum number of steps to run. Must set flag enable_step_budget.'
- 'This flag takes precedence over the run_percentage flag.')
+ 'max_steps',
+ None,
+  'Maximum number of steps to run. Must set flag enable_step_budget. '
+ 'This flag takes precedence over the run_percentage flag.',
+)
flags.DEFINE_bool(
- 'enable_step_budget',
- False,
- 'Flag that has to be explicitly set to override time budgets to step budget percentage.'
+ 'enable_step_budget',
+ False,
+ 'Flag that has to be explicitly set to override time budgets to step budget percentage.',
)
FLAGS = flags.FLAGS
def read_held_out_workloads(filename):
- with open(filename, "r") as f:
+ with open(filename, 'r') as f:
held_out_workloads = json.load(f)
return held_out_workloads
@@ -132,11 +142,13 @@ def kill_containers():
def gpu_is_active():
- output = subprocess.check_output([
+ output = subprocess.check_output(
+ [
'nvidia-smi',
'--query-gpu=utilization.gpu',
- '--format=csv,noheader,nounits'
- ])
+ '--format=csv,noheader,nounits',
+ ]
+ )
return any(int(x) > 0 for x in output.decode().splitlines())
@@ -151,7 +163,8 @@ def wait_until_container_not_running(sleep_interval=5 * 60):
gpu_last_active = datetime.datetime.now().timestamp()
if (datetime.datetime.now().timestamp() - gpu_last_active) > 45 * 60:
kill_containers(
- "Killing container: GPUs have been inactive > 45 minutes...")
+ 'Killing container: GPUs have been inactive > 45 minutes...'
+ )
time.sleep(sleep_interval)
return
@@ -167,7 +180,9 @@ def main(_):
hparam_start_index_flag = ''
hparam_end_index_flag = ''
if FLAGS.hparam_start_index:
- hparam_start_index_flag = f'--hparam_start_index {FLAGS.hparam_start_index} '
+ hparam_start_index_flag = (
+ f'--hparam_start_index {FLAGS.hparam_start_index} '
+ )
if FLAGS.hparam_end_index:
hparam_end_index_flag = f'--hparam_end_index {FLAGS.hparam_end_index} '
study_start_index = FLAGS.study_start_index if FLAGS.study_start_index else 0
@@ -178,7 +193,9 @@ def main(_):
additional_requirements_path_flag = ''
if FLAGS.additional_requirements_path:
- additional_requirements_path_flag = f'--additional_requirements_path {FLAGS.additional_requirements_path} '
+ additional_requirements_path_flag = (
+ f'--additional_requirements_path {FLAGS.additional_requirements_path} '
+ )
submission_id = FLAGS.submission_id
@@ -188,7 +205,7 @@ def main(_):
rng_seed = struct.unpack('I', os.urandom(4))[0]
logging.info('Using RNG seed %d', rng_seed)
- rng_key = (prng.fold_in(prng.PRNGKey(rng_seed), hash(submission_id)))
+ rng_key = prng.fold_in(prng.PRNGKey(rng_seed), hash(submission_id))
with open(FLAGS.workload_metadata_path) as f:
workload_metadata = json.load(f)
@@ -199,20 +216,24 @@ def main(_):
# Read heldout workloads
if FLAGS.held_out_workloads_config_path:
held_out_workloads = read_held_out_workloads(
- FLAGS.held_out_workloads_config_path)
+ FLAGS.held_out_workloads_config_path
+ )
workloads = workloads + held_out_workloads
# Filter workloads if explicit workloads specified
if FLAGS.workloads is not None:
workloads = list(
- filter(lambda x: x in FLAGS.workloads.split(','), workloads))
+ filter(lambda x: x in FLAGS.workloads.split(','), workloads)
+ )
if len(workloads) != len(FLAGS.workloads.split(',')):
unmatched_workloads = set(FLAGS.workloads.split(',')) - set(workloads)
raise ValueError(f'Invalid workload name {unmatched_workloads}')
rng_subkeys = prng.split(rng_key, num_studies)
- for study_index, rng_subkey in zip(range(study_start_index, study_end_index + 1), rng_subkeys):
+ for study_index, rng_subkey in zip(
+ range(study_start_index, study_end_index + 1), rng_subkeys
+ ):
print('-' * 100)
print('*' * 40, f'Starting study {study_index + 1}/{num_studies}', '*' * 40)
print('-' * 100)
@@ -225,40 +246,46 @@ def main(_):
base_workload_name = get_base_workload_name(workload)
wait_until_container_not_running()
os.system(
- "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches
+ "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'"
+ ) # clear caches
print('=' * 100)
dataset = workload_metadata[base_workload_name]['dataset']
max_steps_flag = ''
if FLAGS.enable_step_budget:
- run_fraction = FLAGS.run_percentage / 100.
+ run_fraction = FLAGS.run_percentage / 100.0
if FLAGS.max_steps is None:
- max_steps = int(workload_metadata[base_workload_name]['max_steps'] *
- run_fraction)
+ max_steps = int(
+ workload_metadata[base_workload_name]['max_steps'] * run_fraction
+ )
else:
max_steps = FLAGS.max_steps
max_steps_flag = f'-m {max_steps}'
mount_repo_flag = ''
if FLAGS.local:
- mount_repo_flag = '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency '
- command = ('docker run -t -d -v /home/kasimbeg/data/:/data/ '
- '-v /home/kasimbeg/experiment_runs/:/experiment_runs '
- '-v /home/kasimbeg/experiment_runs/logs:/logs '
- f'{mount_repo_flag}'
- '--gpus all --ipc=host '
- f'{docker_image_url} '
- f'-d {dataset} '
- f'-f {framework} '
- f'-s {submission_path} '
- f'-w {workload} '
- f'-e {study_dir} '
- f'{max_steps_flag} '
- f'--num_tuning_trials {num_tuning_trials} '
- f'--rng_seed {run_seed} '
- f'{additional_requirements_path_flag}'
- '-c false '
- '-o true '
- '-i true ')
+ mount_repo_flag = (
+ '-v /home/kasimbeg/algorithmic-efficiency:/algorithmic-efficiency '
+ )
+ command = (
+ 'docker run -t -d -v /home/kasimbeg/data/:/data/ '
+ '-v /home/kasimbeg/experiment_runs/:/experiment_runs '
+ '-v /home/kasimbeg/experiment_runs/logs:/logs '
+ f'{mount_repo_flag}'
+ '--gpus all --ipc=host '
+ f'{docker_image_url} '
+ f'-d {dataset} '
+ f'-f {framework} '
+ f'-s {submission_path} '
+ f'-w {workload} '
+ f'-e {study_dir} '
+ f'{max_steps_flag} '
+ f'--num_tuning_trials {num_tuning_trials} '
+ f'--rng_seed {run_seed} '
+ f'{additional_requirements_path_flag}'
+ '-c false '
+ '-o true '
+ '-i true '
+ )
# Append tuning ruleset flags
tuning_ruleset_flags = ''
@@ -280,18 +307,19 @@ def main(_):
return_code = 0
if return_code == 0:
print(
- f'SUCCESS: container for {framework} {workload} launched successfully'
+ f'SUCCESS: container for {framework} {workload} launched successfully'
)
print(f'Command: {command}')
print(f'Results will be logged to {experiment_name}')
else:
print(
- f'Failed: container for {framework} {workload} failed with exit code {return_code}.'
+ f'Failed: container for {framework} {workload} failed with exit code {return_code}.'
)
print(f'Command: {command}')
wait_until_container_not_running()
os.system(
- "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'") # clear caches
+ "sudo sh -c 'echo 3 > /proc/sys/vm/drop_caches'"
+ ) # clear caches
print('=' * 100)
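
As reformatted above, `run_workloads.py` assembles the full `docker run` invocation as one string and, when `--dry_run` is set, only prints it instead of executing it. A minimal sketch of that pattern, with placeholder values for the image URL, dataset, and submission path (these are not defaults from the script):

```python
import os


def launch_container(command: str, dry_run: bool) -> int:
  """Print the command on a dry run; otherwise hand it to the shell."""
  if dry_run:
    print(f'Dry run, would execute:\n{command}')
    return 0
  return os.system(command)


command = (
  'docker run -t -d '
  '--gpus all --ipc=host '
  'my_docker_image '  # hypothetical image URL
  '-d fastmri '  # dataset
  '-f jax '  # framework
  '-s my_submission.py '  # hypothetical submission path
  '-w fastmri '  # workload
)
launch_container(command, dry_run=True)
```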
diff --git a/scoring/utils/slurm/README.md b/scoring/utils/slurm/README.md
index ed42752dd..a8e41f04b 100644
--- a/scoring/utils/slurm/README.md
+++ b/scoring/utils/slurm/README.md
@@ -1,27 +1,37 @@
# Launching SLURM jobs with SBATCH
+
This folder contains a SLURM batch script that can be used to run jobs, where each job corresponds to a training run on a given workload, training algorithm, random seed, and tuning trial (if on the external tuning ruleset).
To launch jobs:
-1) Generate a job config. The following command will generate a config.json.
-```
+
+1. Generate a job config. The following command will generate a config.json.
+
+```bash
python3 make_job_config.py \
--submission_path \
--tuning_search_space \
--experiment_dir $HOME/experiments/ \
--framework
```
-2) Save the config.json in the same directory you will run the sbatch script from.
-3) Copy the example sbatch script `run_jobs.sh`.
+
+2. Save the config.json in the same directory you will run the sbatch script from.
+3. Copy the example sbatch script `run_jobs.sh`.
+
- Set the task range to the number of tasks in the config.
+
```
#SBATCH --array=0-119
```
+
- Set the output and error logs directory for the SLURM logs.
+
```
#SBATCH --output=experiments///job_%A_%a.out
#SBATCH --error=experiments///job_%A_%a.err
```
+
- Update the GCP project information, docker image, config file path, and the bucket to save the logs to, as necessary:
+
```
REPO="us-central1-docker.pkg.dev"
IMAGE="us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_main"
@@ -31,22 +41,23 @@ docker pull $IMAGE
config_file="$HOME/configs/pmap_job_config.json" # Replace with your config file path
LOGS_BUCKET="algoperf-runs-internal"
```
-4) Submit a SLURM batch job by running:
+
+4. Submit a SLURM batch job by running:
+
```
sbatch run_jobs.sh
```
-
# Set up new SLURM cluster
+
If you are setting up a new cluster, we recommend using the [HPC toolkit to set up a SLURM cluster](https://cloud.google.com/cluster-toolkit/docs/quickstarts/slurm-cluster).
To set up the new cluster:
-1) [Install the Google Cluster Toolkit](https://github.com/GoogleCloudPlatform/cluster-toolkit?tab=readme-ov-file#quickstart).
-2) Create and deploy a packer node to create a base image for the cluster nodes. See [packer builder terraform blueprint](/scoring/utils/slurm/algoperf_slurm_packer_builder.yaml).
-3) Manually update the image:
- 1) Create a VM from the Disk image created in the previous step.
- 2) Install the NVIDIA container toolkit on the VM.
- 3) Transfer the data from GCP bucket to `/opt/data`.
- 4) Create a new disk image from the VM.
-4) Create and deploy the cluster. See [cluster terraform blueprint](/scoring/utils/slurm/algoperf_slurm_cluster.yaml).
-
+1. [Install the Google Cluster Toolkit](https://github.com/GoogleCloudPlatform/cluster-toolkit?tab=readme-ov-file#quickstart).
+2. Create and deploy a packer node to create a base image for the cluster nodes. See [packer builder terraform blueprint](/scoring/utils/slurm/algoperf_slurm_packer_builder.yaml).
+3. Manually update the image:
+ 1. Create a VM from the Disk image created in the previous step.
+ 2. Install the NVIDIA container toolkit on the VM.
+ 3. Transfer the data from GCP bucket to `/opt/data`.
+ 4. Create a new disk image from the VM.
+4. Create and deploy the cluster. See [cluster terraform blueprint](/scoring/utils/slurm/algoperf_slurm_cluster.yaml).
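
As a companion to the launch steps above: the sbatch script itself is not included in this patch, but each array task can look up its own entry in the generated `config.json` via `SLURM_ARRAY_TASK_ID` (the exact mechanism in `run_jobs.sh` may differ; this is only a sketch):

```python
import json
import os

# SLURM exposes the array index to each task via this environment variable.
task_id = os.environ.get('SLURM_ARRAY_TASK_ID', '0')

with open('config.json', 'r') as f:
  jobs = json.load(f)

job = jobs[task_id]  # config.json is keyed by stringified task indices: "0", "1", ...
print(job['workload'], job['framework'], job['experiment_dir'])
```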
diff --git a/scoring/utils/slurm/algoperf_slurm_cluster.yaml b/scoring/utils/slurm/algoperf_slurm_cluster.yaml
index e6c35e017..073fe98cc 100644
--- a/scoring/utils/slurm/algoperf_slurm_cluster.yaml
+++ b/scoring/utils/slurm/algoperf_slurm_cluster.yaml
@@ -32,74 +32,74 @@ vars:
# bucket: <>
deployment_groups:
-- group: primary
- modules:
- - id: network
- source: modules/network/vpc
+ - group: primary
+ modules:
+ - id: network
+ source: modules/network/vpc
- - id: homefs
- source: community/modules/file-system/nfs-server
- use: [network]
- settings:
- local_mounts: [/home]
- disk_size: 3000
- zone: $(vars.zone)
+ - id: homefs
+ source: community/modules/file-system/nfs-server
+ use: [network]
+ settings:
+ local_mounts: [/home]
+ disk_size: 3000
+ zone: $(vars.zone)
- - id: script
- source: modules/scripts/startup-script
- settings:
+ - id: script
+ source: modules/scripts/startup-script
+ settings:
-- group: cluster
- modules:
- - id: v100_nodeset
- source: community/modules/compute/schedmd-slurm-gcp-v6-nodeset
- use:
- - network
- settings:
- node_count_dynamic_max: 25 # set to 0 if you want node to live forever
- region: $(vars.region)
- zone: $(vars.zone)
- enable_placement: false
- bandwidth_tier: gvnic_enabled
- machine_type: n1-standard-64
- guest_accelerator:
- - type: nvidia-tesla-v100
- count: 8
- instance_image:
- project: $(vars.project_id)
- name: $(vars.image_name)
- instance_image_custom: true
+ - group: cluster
+ modules:
+ - id: v100_nodeset
+ source: community/modules/compute/schedmd-slurm-gcp-v6-nodeset
+ use:
+ - network
+ settings:
+ node_count_dynamic_max: 25 # set to 0 if you want node to live forever
+ region: $(vars.region)
+ zone: $(vars.zone)
+ enable_placement: false
+ bandwidth_tier: gvnic_enabled
+ machine_type: n1-standard-64
+ guest_accelerator:
+ - type: nvidia-tesla-v100
+ count: 8
+ instance_image:
+ project: $(vars.project_id)
+ name: $(vars.image_name)
+ instance_image_custom: true
- - id: v100_partition
- source: community/modules/compute/schedmd-slurm-gcp-v6-partition
- use: [v100_nodeset]
- settings:
- exclusive: false
- partition_name: v100
- is_default: true
+ - id: v100_partition
+ source: community/modules/compute/schedmd-slurm-gcp-v6-partition
+ use: [v100_nodeset]
+ settings:
+ exclusive: false
+ partition_name: v100
+ is_default: true
- - id: slurm_login
- source: community/modules/scheduler/schedmd-slurm-gcp-v6-login
- use: [network]
- settings:
- enable_login_public_ips: true
- instance_image:
- project: $(vars.project_id)
- name: $(vars.image_name)
- instance_image_custom: true
- zone: $(vars.zone)
+ - id: slurm_login
+ source: community/modules/scheduler/schedmd-slurm-gcp-v6-login
+ use: [network]
+ settings:
+ enable_login_public_ips: true
+ instance_image:
+ project: $(vars.project_id)
+ name: $(vars.image_name)
+ instance_image_custom: true
+ zone: $(vars.zone)
- - id: slurm_controller
- source: community/modules/scheduler/schedmd-slurm-gcp-v6-controller
- use:
- - network
- - v100_partition
- - homefs
- - slurm_login
- settings:
- enable_controller_public_ips: true
- instance_image:
- project: $(vars.project_id)
- name: $(vars.image_name)
- instance_image_custom: true
- region: $(vars.region)
+ - id: slurm_controller
+ source: community/modules/scheduler/schedmd-slurm-gcp-v6-controller
+ use:
+ - network
+ - v100_partition
+ - homefs
+ - slurm_login
+ settings:
+ enable_controller_public_ips: true
+ instance_image:
+ project: $(vars.project_id)
+ name: $(vars.image_name)
+ instance_image_custom: true
+ region: $(vars.region)
diff --git a/scoring/utils/slurm/algoperf_slurm_pakcer_builder.yaml b/scoring/utils/slurm/algoperf_slurm_pakcer_builder.yaml
index 286728e1d..f3b5be5dd 100644
--- a/scoring/utils/slurm/algoperf_slurm_pakcer_builder.yaml
+++ b/scoring/utils/slurm/algoperf_slurm_pakcer_builder.yaml
@@ -13,7 +13,6 @@
# limitations under the License.
---
-
blueprint_name: algoperf-slurm-packer
vars:
@@ -36,97 +35,97 @@ vars:
# bucket: <>
deployment_groups:
-- group: primary
- modules:
- - id: network
- source: modules/network/vpc
+ - group: primary
+ modules:
+ - id: network
+ source: modules/network/vpc
- - id: script
- source: modules/scripts/startup-script
- settings:
- region: $(vars.region)
- install_ansible: true
- docker:
- enabled: true
- world_writable: true
- # (TODO) Do I need this?
- configure_ssh_host_patterns:
- - 10.0.0.*
- - 10.1.0.*
- - 10.2.0.*
- - 10.3.0.*
- - 10.4.0.*
- - 10.5.0.*
- - 10.6.0.*
- - 10.7.0.*
- - $(vars.slurm_cluster_name)*
- runners:
- - type: shell
- destination: install-ml-libraries.sh
- content: |
- #!/bin/bash
- # this script is designed to execute on Slurm images published by SchedMD that:
- # - are based on Debian distribution of Linux
- # - have NVIDIA drivers pre-installed
+ - id: script
+ source: modules/scripts/startup-script
+ settings:
+ region: $(vars.region)
+ install_ansible: true
+ docker:
+ enabled: true
+ world_writable: true
+ # (TODO) Do I need this?
+ configure_ssh_host_patterns:
+ - 10.0.0.*
+ - 10.1.0.*
+ - 10.2.0.*
+ - 10.3.0.*
+ - 10.4.0.*
+ - 10.5.0.*
+ - 10.6.0.*
+ - 10.7.0.*
+ - $(vars.slurm_cluster_name)*
+ runners:
+ - type: shell
+ destination: install-ml-libraries.sh
+ content: |
+ #!/bin/bash
+ # this script is designed to execute on Slurm images published by SchedMD that:
+ # - are based on Debian distribution of Linux
+ # - have NVIDIA drivers pre-installed
- set -e -o pipefail
+ set -e -o pipefail
- echo "deb https://packages.cloud.google.com/apt google-fast-socket main" > /etc/apt/sources.list.d/google-fast-socket.list
- apt-get update --allow-releaseinfo-change
- apt-get install --assume-yes google-fast-socket
+ echo "deb https://packages.cloud.google.com/apt google-fast-socket main" > /etc/apt/sources.list.d/google-fast-socket.list
+ apt-get update --allow-releaseinfo-change
+ apt-get install --assume-yes google-fast-socket
- CONDA_BASE=/opt/conda
+ CONDA_BASE=/opt/conda
- if [ -d $CONDA_BASE ]; then
- exit 0
- fi
+ if [ -d $CONDA_BASE ]; then
+ exit 0
+ fi
- DL_DIR=\$(mktemp -d)
- cd $DL_DIR
- curl -L -O https://github.com/conda-forge/miniforge/releases/download/24.7.1-2/Miniforge3-24.7.1-2-Linux-x86_64.sh
- HOME=$DL_DIR bash Miniforge3-24.7.1-2-Linux-x86_64.sh -b -p $CONDA_BASE
- cd -
- rm -rf $DL_DIR
- unset DL_DIR
+ DL_DIR=\$(mktemp -d)
+ cd $DL_DIR
+ curl -L -O https://github.com/conda-forge/miniforge/releases/download/24.7.1-2/Miniforge3-24.7.1-2-Linux-x86_64.sh
+ HOME=$DL_DIR bash Miniforge3-24.7.1-2-Linux-x86_64.sh -b -p $CONDA_BASE
+ cd -
+ rm -rf $DL_DIR
+ unset DL_DIR
- source $CONDA_BASE/bin/activate base
- conda init --system
- conda config --system --set auto_activate_base False
- # following channel ordering is important! use strict_priority!
- conda config --system --set channel_priority strict
- conda update -n base conda --yes
+ source $CONDA_BASE/bin/activate base
+ conda init --system
+ conda config --system --set auto_activate_base False
+ # following channel ordering is important! use strict_priority!
+ conda config --system --set channel_priority strict
+ conda update -n base conda --yes
- ### create a virtual environment for tensorflow
- conda create -n tf python=3.11 --yes
- conda activate tf
- pip install tensorflow[and-cuda]==2.18.*
- pip install tensorrt==10.6.*
+ ### create a virtual environment for tensorflow
+ conda create -n tf python=3.11 --yes
+ conda activate tf
+ pip install tensorflow[and-cuda]==2.18.*
+ pip install tensorrt==10.6.*
- ### create a virtual environment for pytorch
- conda create -n pytorch python=3.11 --yes
- conda activate pytorch
- pip install torch torchvision torchaudio
+ ### create a virtual environment for pytorch
+ conda create -n pytorch python=3.11 --yes
+ conda activate pytorch
+ pip install torch torchvision torchaudio
-- group: packer
- modules:
- - id: custom-image
- source: modules/packer/custom-image
- kind: packer
- use:
- - network
- - script
- settings:
- # give VM a public IP to ensure startup script can reach public internet
- # w/o new VPC
- omit_external_ip: false
- source_image_project_id: [schedmd-slurm-public]
- # see latest in https://github.com/GoogleCloudPlatform/slurm-gcp/blob/master/docs/images.md#published-image-family
- source_image_family: slurm-gcp-6-8-debian-11
- # You can find size of source image by using following command
- # gcloud compute images describe-from-family --project schedmd-slurm-public
- disk_size: $(vars.disk_size_gb)
- image_family: $(vars.new_image.family)
- # building this image does not require a GPU-enabled VM
- machine_type: c2-standard-16
- state_timeout: 300m
- zone: $(vars.zone)
+ - group: packer
+ modules:
+ - id: custom-image
+ source: modules/packer/custom-image
+ kind: packer
+ use:
+ - network
+ - script
+ settings:
+ # give VM a public IP to ensure startup script can reach public internet
+ # w/o new VPC
+ omit_external_ip: false
+ source_image_project_id: [schedmd-slurm-public]
+ # see latest in https://github.com/GoogleCloudPlatform/slurm-gcp/blob/master/docs/images.md#published-image-family
+ source_image_family: slurm-gcp-6-8-debian-11
+ # You can find size of source image by using following command
+ # gcloud compute images describe-from-family --project schedmd-slurm-public
+ disk_size: $(vars.disk_size_gb)
+ image_family: $(vars.new_image.family)
+ # building this image does not require a GPU-enabled VM
+ machine_type: c2-standard-16
+ state_timeout: 300m
+ zone: $(vars.zone)
diff --git a/scoring/utils/slurm/config.json b/scoring/utils/slurm/config.json
index dc19e57f7..cb49f9bf4 100644
--- a/scoring/utils/slurm/config.json
+++ b/scoring/utils/slurm/config.json
@@ -1,106 +1,106 @@
{
- "0": {
- "framework": "jax",
- "workload": "imagenet_resnet",
- "dataset": "imagenet",
- "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
- "experiment_dir": "$HOME/experiments/jit_switch/study_0",
- "rng_seed": 411096763,
- "tuning_ruleset": "external",
- "num_tuning_trials": 1,
- "hparam_start_index": 0,
- "hparam_end_index": 1,
- "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
- },
- "1": {
- "framework": "jax",
- "workload": "imagenet_vit",
- "dataset": "imagenet",
- "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
- "experiment_dir": "$HOME/experiments/jit_switch/study_0",
- "rng_seed": -1884713130,
- "tuning_ruleset": "external",
- "num_tuning_trials": 1,
- "hparam_start_index": 0,
- "hparam_end_index": 1,
- "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
- },
- "2": {
- "framework": "jax",
- "workload": "fastmri",
- "dataset": "fastmri",
- "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
- "experiment_dir": "$HOME/experiments/jit_switch/study_0",
- "rng_seed": -214785144,
- "tuning_ruleset": "external",
- "num_tuning_trials": 1,
- "hparam_start_index": 0,
- "hparam_end_index": 1,
- "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
- },
- "3": {
- "framework": "jax",
- "workload": "ogbg",
- "dataset": "ogbg",
- "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
- "experiment_dir": "$HOME/experiments/jit_switch/study_0",
- "rng_seed": -893097833,
- "tuning_ruleset": "external",
- "num_tuning_trials": 1,
- "hparam_start_index": 0,
- "hparam_end_index": 1,
- "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
- },
- "4": {
- "framework": "jax",
- "workload": "wmt",
- "dataset": "wmt",
- "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
- "experiment_dir": "$HOME/experiments/jit_switch/study_0",
- "rng_seed": -1244182279,
- "tuning_ruleset": "external",
- "num_tuning_trials": 1,
- "hparam_start_index": 0,
- "hparam_end_index": 1,
- "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
- },
- "5": {
- "framework": "jax",
- "workload": "librispeech_deepspeech",
- "dataset": "librispeech",
- "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
- "experiment_dir": "$HOME/experiments/jit_switch/study_0",
- "rng_seed": 1546003634,
- "tuning_ruleset": "external",
- "num_tuning_trials": 1,
- "hparam_start_index": 0,
- "hparam_end_index": 1,
- "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
- },
- "6": {
- "framework": "jax",
- "workload": "criteo1tb",
- "dataset": "criteo1tb",
- "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
- "experiment_dir": "$HOME/experiments/jit_switch/study_0",
- "rng_seed": -2062333143,
- "tuning_ruleset": "external",
- "num_tuning_trials": 1,
- "hparam_start_index": 0,
- "hparam_end_index": 1,
- "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
- },
- "7": {
- "framework": "jax",
- "workload": "librispeech_conformer",
- "dataset": "librispeech",
- "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
- "experiment_dir": "$HOME/experiments/jit_switch/study_0",
- "rng_seed": 409209730,
- "tuning_ruleset": "external",
- "num_tuning_trials": 1,
- "hparam_start_index": 0,
- "hparam_end_index": 1,
- "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
- }
-}
\ No newline at end of file
+ "0": {
+ "framework": "jax",
+ "workload": "imagenet_resnet",
+ "dataset": "imagenet",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": 411096763,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "1": {
+ "framework": "jax",
+ "workload": "imagenet_vit",
+ "dataset": "imagenet",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": -1884713130,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "2": {
+ "framework": "jax",
+ "workload": "fastmri",
+ "dataset": "fastmri",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": -214785144,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "3": {
+ "framework": "jax",
+ "workload": "ogbg",
+ "dataset": "ogbg",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": -893097833,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "4": {
+ "framework": "jax",
+ "workload": "wmt",
+ "dataset": "wmt",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": -1244182279,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "5": {
+ "framework": "jax",
+ "workload": "librispeech_deepspeech",
+ "dataset": "librispeech",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": 1546003634,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "6": {
+ "framework": "jax",
+ "workload": "criteo1tb",
+ "dataset": "criteo1tb",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": -2062333143,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ },
+ "7": {
+ "framework": "jax",
+ "workload": "librispeech_conformer",
+ "dataset": "librispeech",
+ "submission_path": "reference_algorithms/paper_baselines/adamw/jax/submission.py",
+ "experiment_dir": "$HOME/experiments/jit_switch/study_0",
+ "rng_seed": 409209730,
+ "tuning_ruleset": "external",
+ "num_tuning_trials": 1,
+ "hparam_start_index": 0,
+ "hparam_end_index": 1,
+ "tuning_search_space": "reference_algorithms/paper_baselines/adamw/tuning_search_space.json"
+ }
+}
diff --git a/scoring/utils/slurm/make_job_config.py b/scoring/utils/slurm/make_job_config.py
index 116e70459..f6a1ca158 100644
--- a/scoring/utils/slurm/make_job_config.py
+++ b/scoring/utils/slurm/make_job_config.py
@@ -6,60 +6,66 @@
--experiment_dir $HOME/experiments/ \
--framework
"""
+
import json
import os
-from absl import app
-from absl import flags
import jax
+from absl import app, flags
SUBMISSION_PATH = 'submissions_algorithms/submissions/self_tuning/schedule_free_adamw_v2/submission.py'
-EXPERIMENT_DIR = 'submissions/rolling_leaderboard/self_tuning/schedule_free_adamw_v2'
+EXPERIMENT_DIR = (
+ 'submissions/rolling_leaderboard/self_tuning/schedule_free_adamw_v2'
+)
TUNING_SEARCH_SPACE = None
FRAMEWORK = 'pytorch'
TUNING_RULESET = 'self'
flags.DEFINE_string(
- 'submission_path',
- SUBMISSION_PATH,
- 'Path to submission module relative to algorithmic-efficiency dir.')
+ 'submission_path',
+ SUBMISSION_PATH,
+ 'Path to submission module relative to algorithmic-efficiency dir.',
+)
+flags.DEFINE_string(
+ 'tuning_search_space',
+ TUNING_SEARCH_SPACE,
+ 'Path to tuning search space for submission module relative to algorithmic-efficiency dir.',
+)
flags.DEFINE_string(
- 'tuning_search_space',
- TUNING_SEARCH_SPACE,
- 'Path to tuning search space for submission module relative to algorithmic-efficiency dir.'
+ 'experiment_dir',
+ EXPERIMENT_DIR,
+ 'Path to experiment dir where logs will be saved.',
)
-flags.DEFINE_string('experiment_dir',
- EXPERIMENT_DIR,
- 'Path to experiment dir where logs will be saved.')
flags.DEFINE_enum(
- 'framework',
- FRAMEWORK,
- enum_values=['jax', 'pytorch'],
- help='Can be either pytorch or jax.')
+ 'framework',
+ FRAMEWORK,
+ enum_values=['jax', 'pytorch'],
+ help='Can be either pytorch or jax.',
+)
flags.DEFINE_integer('seed', 0, 'RNG seed to generate study seeds from.')
flags.DEFINE_enum(
- 'tuning_ruleset',
- TUNING_RULESET,
- enum_values=['external', 'self'],
- help='Which tuning ruleset to score this submission on. Can be external or self.'
+ 'tuning_ruleset',
+ TUNING_RULESET,
+ enum_values=['external', 'self'],
+ help='Which tuning ruleset to score this submission on. Can be external or self.',
)
FLAGS = flags.FLAGS
-MIN_INT = -2**(31)
-MAX_INT = 2**(31) - 1
+MIN_INT = -(2 ** (31))
+MAX_INT = 2 ** (31) - 1
NUM_TUNING_TRIALS = 5 # For external tuning ruleset
NUM_STUDIES = 3
WORKLOADS = {
- "imagenet_resnet": {"dataset": "imagenet"},
- "imagenet_vit": {"dataset": "imagenet"},
- "fastmri": {"dataset": "fastmri"},
- "ogbg": {"dataset": "ogbg"},
- "wmt": {"dataset": "wmt"},
- "librispeech_deepspeech": {"dataset": "librispeech"},
- "criteo1tb": {"dataset": "criteo1tb"},
- "librispeech_conformer": {"dataset": "librispeech"}
+ 'imagenet_resnet': {'dataset': 'imagenet'},
+ 'imagenet_vit': {'dataset': 'imagenet'},
+ 'fastmri': {'dataset': 'fastmri'},
+ 'ogbg': {'dataset': 'ogbg'},
+ 'wmt': {'dataset': 'wmt'},
+ 'librispeech_deepspeech': {'dataset': 'librispeech'},
+ 'criteo1tb': {'dataset': 'criteo1tb'},
+ 'librispeech_conformer': {'dataset': 'librispeech'},
}
@@ -81,7 +87,7 @@ def main(_):
print(seed)
# Add job
job = {}
- study_dir = os.path.join(FLAGS.experiment_dir, f"study_{study_index}")
+ study_dir = os.path.join(FLAGS.experiment_dir, f'study_{study_index}')
job['framework'] = FLAGS.framework
job['workload'] = workload
job['dataset'] = WORKLOADS[workload]['dataset']
@@ -103,7 +109,7 @@ def main(_):
print(seed)
# Add job
job = {}
- study_dir = os.path.join(FLAGS.experiment_dir, f"study_{study_index}")
+ study_dir = os.path.join(FLAGS.experiment_dir, f'study_{study_index}')
job['framework'] = FLAGS.framework
job['workload'] = workload
job['dataset'] = WORKLOADS[workload]['dataset']
@@ -119,7 +125,7 @@ def main(_):
# Convert job array to dict with job indices
job_dict = {}
for i, job in enumerate(jobs):
- job_dict[f"{i}"] = job
+ job_dict[f'{i}'] = job
with open('config.json', 'w') as f:
json.dump(job_dict, f, indent=4)
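
The hunks above only show fragments of the seed handling in `make_job_config.py`; below is a minimal sketch, under the assumption that one integer seed per study is drawn from a base seed with JAX inside the `[MIN_INT, MAX_INT]` range defined above (note `maxval` is exclusive in `jax.random.randint`):

```python
import jax

MIN_INT = -(2**31)
MAX_INT = 2**31 - 1
NUM_STUDIES = 3

base_key = jax.random.PRNGKey(0)  # stands in for FLAGS.seed
study_keys = jax.random.split(base_key, NUM_STUDIES)
study_seeds = [
  int(jax.random.randint(k, (), minval=MIN_INT, maxval=MAX_INT))
  for k in study_keys
]
print(study_seeds)
```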
diff --git a/scoring/utils/workload_metadata_external_tuning.json b/scoring/utils/workload_metadata_external_tuning.json
index c205d28b2..c7d4ae195 100644
--- a/scoring/utils/workload_metadata_external_tuning.json
+++ b/scoring/utils/workload_metadata_external_tuning.json
@@ -1,34 +1,34 @@
{
- "imagenet_resnet": {
- "max_steps": 186666,
- "dataset": "imagenet"
- },
- "imagenet_vit": {
- "max_steps": 186666,
- "dataset": "imagenet"
- },
- "fastmri": {
- "max_steps": 36189,
- "dataset": "fastmri"
- },
- "ogbg": {
- "max_steps": 80000,
- "dataset": "ogbg"
- },
- "wmt": {
- "max_steps": 133333,
- "dataset": "wmt"
- },
- "librispeech_deepspeech": {
- "max_steps": 48000,
- "dataset": "librispeech"
- },
- "criteo1tb": {
- "max_steps": 10666,
- "dataset": "criteo1tb"
- },
- "librispeech_conformer": {
- "max_steps": 80000,
- "dataset": "librispeech"
- }
- }
\ No newline at end of file
+ "imagenet_resnet": {
+ "max_steps": 186666,
+ "dataset": "imagenet"
+ },
+ "imagenet_vit": {
+ "max_steps": 186666,
+ "dataset": "imagenet"
+ },
+ "fastmri": {
+ "max_steps": 36189,
+ "dataset": "fastmri"
+ },
+ "ogbg": {
+ "max_steps": 80000,
+ "dataset": "ogbg"
+ },
+ "wmt": {
+ "max_steps": 133333,
+ "dataset": "wmt"
+ },
+ "librispeech_deepspeech": {
+ "max_steps": 48000,
+ "dataset": "librispeech"
+ },
+ "criteo1tb": {
+ "max_steps": 10666,
+ "dataset": "criteo1tb"
+ },
+ "librispeech_conformer": {
+ "max_steps": 80000,
+ "dataset": "librispeech"
+ }
+}
diff --git a/scoring/utils/workload_metadata_self_tuning.json b/scoring/utils/workload_metadata_self_tuning.json
index 105d5c52f..9d3e6b93d 100644
--- a/scoring/utils/workload_metadata_self_tuning.json
+++ b/scoring/utils/workload_metadata_self_tuning.json
@@ -1,34 +1,34 @@
{
- "imagenet_resnet": {
- "max_steps": 559998,
- "dataset": "imagenet"
- },
- "imagenet_vit": {
- "max_steps": 559998,
- "dataset": "imagenet"
- },
- "fastmri": {
- "max_steps": 108567,
- "dataset": "fastmri"
- },
- "ogbg": {
- "max_steps": 240000,
- "dataset": "ogbg"
- },
- "wmt": {
- "max_steps": 399999,
- "dataset": "wmt"
- },
- "librispeech_deepspeech": {
- "max_steps": 144000,
- "dataset": "librispeech"
- },
- "criteo1tb": {
- "max_steps": 31998,
- "dataset": "criteo1tb"
- },
- "librispeech_conformer": {
- "max_steps": 240000,
- "dataset": "librispeech"
- }
- }
\ No newline at end of file
+ "imagenet_resnet": {
+ "max_steps": 559998,
+ "dataset": "imagenet"
+ },
+ "imagenet_vit": {
+ "max_steps": 559998,
+ "dataset": "imagenet"
+ },
+ "fastmri": {
+ "max_steps": 108567,
+ "dataset": "fastmri"
+ },
+ "ogbg": {
+ "max_steps": 240000,
+ "dataset": "ogbg"
+ },
+ "wmt": {
+ "max_steps": 399999,
+ "dataset": "wmt"
+ },
+ "librispeech_deepspeech": {
+ "max_steps": 144000,
+ "dataset": "librispeech"
+ },
+ "criteo1tb": {
+ "max_steps": 31998,
+ "dataset": "criteo1tb"
+ },
+ "librispeech_conformer": {
+ "max_steps": 240000,
+ "dataset": "librispeech"
+ }
+}
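
For reference, this metadata is what `run_workloads.py` reads to derive a per-workload step budget when `--enable_step_budget` is set. A minimal sketch with a single entry copied from the file above and an illustrative `run_percentage`:

```python
import json

# One entry from workload_metadata_self_tuning.json above.
metadata = json.loads("""
{
  "fastmri": {"max_steps": 108567, "dataset": "fastmri"}
}
""")

run_percentage = 50  # illustrative value of the --run_percentage flag
run_fraction = run_percentage / 100.0
max_steps = int(metadata['fastmri']['max_steps'] * run_fraction)
max_steps_flag = f'-m {max_steps}'
print(max_steps_flag)  # -m 54283
```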
From f4ae9beac42dbcad30b77cfc386616ab29214061 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 10:19:28 +0200
Subject: [PATCH 103/123] Format reference_algorithms/
---
.../cifar/cifar_jax/submission.py | 178 +-
.../cifar/cifar_pytorch/submission.py | 129 +-
.../mnist/mnist_jax/submission.py | 156 +-
.../mnist/mnist_pytorch/submission.py | 105 +-
.../adafactor/jax/sharded_adafactor.py | 275 +--
.../adafactor/jax/submission.py | 205 +-
.../adafactor/pytorch/submission.py | 288 +--
.../paper_baselines/adamw/jax/submission.py | 204 +-
.../adamw/pytorch/submission.py | 157 +-
.../paper_baselines/lamb/jax/submission.py | 204 +-
.../lamb/pytorch/submission.py | 228 +-
.../momentum/jax/submission.py | 220 +-
.../momentum/pytorch/submission.py | 171 +-
.../paper_baselines/nadamw/jax/submission.py | 269 ++-
.../nadamw/pytorch/submission.py | 274 ++-
.../nesterov/jax/submission.py | 220 +-
.../nesterov/pytorch/submission.py | 171 +-
.../paper_baselines/sam/jax/submission.py | 256 ++-
.../paper_baselines/sam/pytorch/submission.py | 191 +-
.../shampoo/jax/distributed_shampoo.py | 2007 +++++++++--------
.../paper_baselines/shampoo/jax/submission.py | 211 +-
.../cosine_warmup.py | 30 +-
.../criteo1tb/tuning_search_space.json | 20 +-
.../data_selection.py | 17 +-
.../target_setting_algorithms/jax_adamw.py | 52 +-
.../target_setting_algorithms/jax_momentum.py | 79 +-
.../target_setting_algorithms/jax_nadamw.py | 96 +-
.../target_setting_algorithms/jax_nesterov.py | 79 +-
.../jax_submission_base.py | 148 +-
.../pytorch_adamw.py | 47 +-
.../pytorch_momentum.py | 58 +-
.../pytorch_nadamw.py | 172 +-
.../pytorch_nesterov.py | 58 +-
.../pytorch_submission_base.py | 101 +-
34 files changed, 3915 insertions(+), 3161 deletions(-)
diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py
index 3d8e35eaa..37f74ac45 100644
--- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py
+++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py
@@ -19,26 +19,29 @@ def get_batch_size(workload_name):
def cosine_decay(lr, step, total_steps):
- ratio = jnp.maximum(0., step / total_steps)
- mult = 0.5 * (1. + jnp.cos(jnp.pi * ratio))
+ ratio = jnp.maximum(0.0, step / total_steps)
+ mult = 0.5 * (1.0 + jnp.cos(jnp.pi * ratio))
return mult * lr
-def create_learning_rate_fn(hparams: spec.Hyperparameters,
- steps_per_epoch: int):
+def create_learning_rate_fn(
+ hparams: spec.Hyperparameters, steps_per_epoch: int
+):
"""Create learning rate schedule."""
- base_learning_rate = hparams.learning_rate * get_batch_size('cifar') / 128.
+ base_learning_rate = hparams.learning_rate * get_batch_size('cifar') / 128.0
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=base_learning_rate,
- transition_steps=hparams.warmup_epochs * steps_per_epoch)
+ init_value=0.0,
+ end_value=base_learning_rate,
+ transition_steps=hparams.warmup_epochs * steps_per_epoch,
+ )
cosine_epochs = max(hparams.num_epochs - hparams.warmup_epochs, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=base_learning_rate,
- decay_steps=cosine_epochs * steps_per_epoch)
+ init_value=base_learning_rate, decay_steps=cosine_epochs * steps_per_epoch
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn],
- boundaries=[hparams.warmup_epochs * steps_per_epoch])
+ schedules=[warmup_fn, cosine_fn],
+ boundaries=[hparams.warmup_epochs * steps_per_epoch],
+ )
return schedule_fn
@@ -46,51 +49,59 @@ def optimizer(hyperparameters: spec.Hyperparameters, num_train_examples: int):
steps_per_epoch = num_train_examples // get_batch_size('cifar')
learning_rate_fn = create_learning_rate_fn(hyperparameters, steps_per_epoch)
opt_init_fn, opt_update_fn = optax.sgd(
- nesterov=True,
- momentum=hyperparameters.momentum,
- learning_rate=learning_rate_fn)
+ nesterov=True,
+ momentum=hyperparameters.momentum,
+ learning_rate=learning_rate_fn,
+ )
return opt_init_fn, opt_update_fn
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
del model_params
del model_state
del rng
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
- opt_init_fn, opt_update_fn = optimizer(hyperparameters,
- workload.num_train_examples)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
+ opt_init_fn, opt_update_fn = optimizer(
+ hyperparameters, workload.num_train_examples
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, None, 0, 0),
- static_broadcasted_argnums=(0, 1))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- hyperparameters,
- batch,
- rng):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, None, 0, 0),
+ static_broadcasted_argnums=(0, 1),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ hyperparameters,
+ batch,
+ rng,
+):
def _loss_fn(params):
"""loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(batch['targets'], logits)
loss = loss_dict['summed'] / loss_dict['n_valid_examples']
weight_penalty_params = jax.tree_util.tree_leaves(params)
@@ -102,25 +113,27 @@ def _loss_fn(params):
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(_, new_model_state), grad = grad_fn(current_param_container)
grad = lax.pmean(grad, axis_name='batch')
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -130,21 +143,30 @@ def update_params(
optimizer_state, opt_update_fn = optimizer_state
per_device_rngs = jax.random.split(rng, jax.local_device_count())
new_optimizer_state, new_params, new_model_state = pmapped_train_step(
- workload, opt_update_fn, model_state, optimizer_state,
- current_param_container, hyperparameters, batch, per_device_rngs)
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ hyperparameters,
+ batch,
+ per_device_rngs,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -158,14 +180,16 @@ def prepare_for_eval(workload: spec.Workload,
# Not allowed to update the model parameters, hyperparameters, global step, or
# optimzier state.
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
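
The reformatted `create_learning_rate_fn` above chains a linear warmup into a cosine decay via `optax.join_schedules`; here is a standalone sketch with illustrative values for the hyperparameters and `steps_per_epoch`:

```python
import optax

base_learning_rate = 0.1
warmup_epochs, num_epochs, steps_per_epoch = 5, 100, 390  # illustrative values

warmup_fn = optax.linear_schedule(
  init_value=0.0,
  end_value=base_learning_rate,
  transition_steps=warmup_epochs * steps_per_epoch,
)
cosine_fn = optax.cosine_decay_schedule(
  init_value=base_learning_rate,
  decay_steps=(num_epochs - warmup_epochs) * steps_per_epoch,
)
schedule_fn = optax.join_schedules(
  schedules=[warmup_fn, cosine_fn],
  boundaries=[warmup_epochs * steps_per_epoch],
)

# The schedule is a plain function of the step count:
print(schedule_fn(0), schedule_fn(warmup_epochs * steps_per_epoch))

# It can be passed directly to the optimizer, as in the submission above:
tx = optax.sgd(learning_rate=schedule_fn, momentum=0.9, nesterov=True)
```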
diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py
index d8b91f83a..5fd51c3b2 100644
--- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py
+++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py
@@ -16,56 +16,63 @@ def get_batch_size(workload_name):
return batch_sizes[workload_name]
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
del workload
del model_state
del rng
- base_lr = hyperparameters.learning_rate * get_batch_size('cifar') / 128.
+ base_lr = hyperparameters.learning_rate * get_batch_size('cifar') / 128.0
optimizer_state = {
- 'optimizer':
- torch.optim.SGD(
- model_params.parameters(),
- lr=base_lr,
- momentum=hyperparameters.momentum,
- weight_decay=hyperparameters.l2),
+ 'optimizer': torch.optim.SGD(
+ model_params.parameters(),
+ lr=base_lr,
+ momentum=hyperparameters.momentum,
+ weight_decay=hyperparameters.l2,
+ ),
}
scheduler1 = LinearLR(
- optimizer_state['optimizer'],
- start_factor=1e-5,
- end_factor=1.,
- total_iters=hyperparameters.warmup_epochs)
+ optimizer_state['optimizer'],
+ start_factor=1e-5,
+ end_factor=1.0,
+ total_iters=hyperparameters.warmup_epochs,
+ )
cosine_epochs = max(
- hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1)
+ hyperparameters.num_epochs - hyperparameters.warmup_epochs, 1
+ )
scheduler2 = CosineAnnealingLR(
- optimizer_state['optimizer'], T_max=cosine_epochs)
+ optimizer_state['optimizer'], T_max=cosine_epochs
+ )
optimizer_state['scheduler'] = SequentialLR(
- optimizer_state['optimizer'],
- schedulers=[scheduler1, scheduler2],
- milestones=[hyperparameters.warmup_epochs])
+ optimizer_state['optimizer'],
+ schedulers=[scheduler1, scheduler2],
+ milestones=[hyperparameters.warmup_epochs],
+ )
return optimizer_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del current_params_types
del hyperparameters
@@ -78,15 +85,17 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'], logits_batch=logits_batch)
+ label_batch=batch['targets'], logits_batch=logits_batch
+ )
loss = loss_dict['summed'] / loss_dict['n_valid_examples']
loss.backward()
@@ -99,16 +108,18 @@ def update_params(
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -122,14 +133,16 @@ def prepare_for_eval(workload: spec.Workload,
# Not allowed to update the model parameters, hyperparameters, global step, or
 # optimizer state.
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
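
The CIFAR PyTorch hunk above chains a linear warmup into a cosine decay via ``SequentialLR``. A minimal runnable sketch of that composition, with hypothetical values (warmup_epochs=5, num_epochs=100, learning_rate=0.1, momentum=0.9) standing in for ``hyperparameters``::

    import torch
    from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

    # Hypothetical values standing in for `hyperparameters`.
    warmup_epochs, num_epochs, learning_rate, momentum = 5, 100, 0.1, 0.9

    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
    warmup = LinearLR(optimizer, start_factor=1e-5, end_factor=1.0, total_iters=warmup_epochs)
    cosine = CosineAnnealingLR(optimizer, T_max=max(num_epochs - warmup_epochs, 1))
    scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_epochs])

    for _ in range(num_epochs):
      optimizer.step()    # the submission steps the scheduler once per epoch
      scheduler.step()
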
diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py
index c1f54597d..3ef97577f 100644
--- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py
+++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py
@@ -18,50 +18,59 @@ def get_batch_size(workload_name):
return batch_sizes[workload_name]
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
del model_params
del model_state
del rng
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
opt_init_fn, opt_update_fn = optax.chain(
- optax.scale_by_adam(
- b1=1.0 - hyperparameters.one_minus_beta_1,
- b2=0.999,
- eps=hyperparameters.epsilon),
- optax.scale(-hyperparameters.learning_rate))
+ optax.scale_by_adam(
+ b1=1.0 - hyperparameters.one_minus_beta_1,
+ b2=0.999,
+ eps=hyperparameters.epsilon,
+ ),
+ optax.scale(-hyperparameters.learning_rate),
+ )
return jax_utils.replicate(opt_init_fn(params_zeros_like)), opt_update_fn
# We need to jax.pmap here instead of inside update_params because the latter
# would recompile the function every step.
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, None, 0, 0, 0),
- static_broadcasted_argnums=(0, 1))
-def pmapped_update_params(workload: spec.Workload,
- opt_update_fn,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- optimizer_state: spec.OptimizerState,
- rng: spec.RandomState) -> spec.UpdateReturn:
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, None, 0, 0, 0),
+ static_broadcasted_argnums=(0, 1),
+)
+def pmapped_update_params(
+ workload: spec.Workload,
+ opt_update_fn,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ optimizer_state: spec.OptimizerState,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
del hyperparameters
def loss_fn(params):
logits_batch, new_model_state = workload.model_fn(
- params=params,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=params,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(batch['targets'], logits_batch)
loss = loss_dict['summed'] / loss_dict['n_valid_examples']
return loss, new_model_state
@@ -69,25 +78,27 @@ def loss_fn(params):
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(_, new_model_state), grad = grad_fn(current_param_container)
grad = lax.pmean(grad, axis_name='batch')
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -98,27 +109,30 @@ def update_params(
per_device_rngs = jax.random.split(rng, jax.local_device_count())
optimizer_state, opt_update_fn = optimizer_state
new_optimizer_state, updated_params, new_model_state = pmapped_update_params(
- workload,
- opt_update_fn,
- current_param_container,
- model_state,
- hyperparameters,
- batch,
- optimizer_state,
- per_device_rngs)
+ workload,
+ opt_update_fn,
+ current_param_container,
+ model_state,
+ hyperparameters,
+ batch,
+ optimizer_state,
+ per_device_rngs,
+ )
return (new_optimizer_state, opt_update_fn), updated_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -132,14 +146,16 @@ def prepare_for_eval(workload: spec.Workload,
# Not allowed to update the model parameters, hyperparameters, global step, or
 # optimizer state.
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
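
The MNIST JAX submission pmaps its update once at module level (the comment above notes that pmapping inside ``update_params`` would recompile every step) and averages per-device gradients with ``lax.pmean``. A toy sketch of that cross-device averaging only, not the submission itself; it also runs on a single device::

    import functools

    import jax
    import jax.numpy as jnp
    from jax import lax

    @functools.partial(jax.pmap, axis_name='batch')
    def device_mean(x):
      # Same collective the submission applies to the gradient tree.
      return lax.pmean(x, axis_name='batch')

    n = jax.local_device_count()
    per_device = jnp.arange(n, dtype=jnp.float32).reshape(n, 1)
    print(device_mean(per_device))  # every device now holds the same mean
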
diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py
index dedd96793..9940fca6e 100644
--- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py
+++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py
@@ -13,38 +13,41 @@ def get_batch_size(workload_name):
return batch_sizes[workload_name]
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
del model_state
del workload
del rng
optimizer_state = {
- 'optimizer':
- torch.optim.Adam(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(1.0 - hyperparameters.one_minus_beta_1, 0.999),
- eps=hyperparameters.epsilon),
+ 'optimizer': torch.optim.Adam(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(1.0 - hyperparameters.one_minus_beta_1, 0.999),
+ eps=hyperparameters.epsilon,
+ ),
}
return optimizer_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del hyperparameters
del loss_type
@@ -59,15 +62,17 @@ def update_params(
param.grad = None
output, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'], logits_batch=output)
+ label_batch=batch['targets'], logits_batch=output
+ )
loss = loss_dict['summed'] / loss_dict['n_valid_examples']
loss.backward()
optimizer_state['optimizer'].step()
@@ -75,16 +80,18 @@ def update_params(
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -98,14 +105,16 @@ def prepare_for_eval(workload: spec.Workload,
# Not allowed to update the model parameters, hyperparameters, global step, or
 # optimizer state.
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
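
Both MNIST submissions build Adam with ``betas=(1.0 - one_minus_beta_1, 0.999)``, i.e. the first-moment decay is specified through its complement. A sketch of the equivalent construction with hypothetical hyperparameter values::

    import torch

    # Hypothetical hyperparameter values, for illustration only.
    one_minus_beta_1, epsilon, learning_rate = 0.1, 1e-8, 1e-3

    model = torch.nn.Linear(784, 10)
    optimizer = torch.optim.Adam(
      model.parameters(),
      lr=learning_rate,
      betas=(1.0 - one_minus_beta_1, 0.999),  # beta1 = 0.9, beta2 fixed
      eps=epsilon,
    )
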
diff --git a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py
index ff98464ae..c83f14a13 100644
--- a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py
+++ b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py
@@ -38,8 +38,9 @@
NestedHParams = Any
-def to_quantized(fvalue: JTensor,
- quantized_dtype: jnp.dtype) -> Tuple[JTensor, JTensor]:
+def to_quantized(
+ fvalue: JTensor, quantized_dtype: jnp.dtype
+) -> Tuple[JTensor, JTensor]:
"""Converts floating point values `fvalues` to quantized values.
We use a very simple quantization scheme where the range is symmetric around
@@ -82,16 +83,17 @@ def to_quantized(fvalue: JTensor,
# We first decide the scale.
if fvalue.ndim < 1:
raise ValueError(
- f'Input array {fvalue} must have a strictly positive number of '
- 'dimensions.')
+ f'Input array {fvalue} must have a strictly positive number of '
+ 'dimensions.'
+ )
max_abs = jnp.max(jnp.abs(fvalue), axis=0)
bucket_size = max_abs / num_buckets
bs_expanded = bucket_size[jnp.newaxis, ...]
# To avoid divide by 0.0
- bs_nonzero = jnp.where(bs_expanded > 0.0,
- bs_expanded,
- jnp.ones_like(bs_expanded))
+ bs_nonzero = jnp.where(
+ bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded)
+ )
ratio = fvalue / bs_nonzero
# We use rounding to remove bias.
quantized = jnp.round(ratio)
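
``to_quantized`` above picks a per-column bucket size from the maximum magnitude and rounds the ratio to integer buckets. A toy sketch of that scheme; ``num_buckets = 127`` is an assumed stand-in for what the helper derives from ``quantized_dtype``, and the last line is only the assumed inverse mapping::

    import jax.numpy as jnp

    fvalue = jnp.array([[0.5, -2.0],
                        [1.5, 4.0]])
    num_buckets = 127  # assumed; the helper derives this from quantized_dtype
    max_abs = jnp.max(jnp.abs(fvalue), axis=0)
    bucket_size = max_abs / num_buckets
    bucket_size = jnp.where(
      bucket_size > 0.0, bucket_size, jnp.ones_like(bucket_size)
    )
    quantized = jnp.round(fvalue / bucket_size)  # integer-valued buckets
    roundtrip = quantized * bucket_size          # assumed inverse mapping
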
@@ -128,8 +130,8 @@ def adafactor_decay_rate_adam(beta2: float, step_counter: JTensor) -> JTensor:
"""
step = step_counter
beta2 = jnp.array(beta2, dtype=jnp.float32)
- t = step + 1.
- return beta2 * (1. - jnp.power(beta2, t - 1.)) / (1. - jnp.power(beta2, t))
+ t = step + 1.0
+ return beta2 * (1.0 - jnp.power(beta2, t - 1.0)) / (1.0 - jnp.power(beta2, t))
def adafactor_decay_rate_pow(exponent: float, step_counter: JTensor) -> JTensor:
@@ -145,7 +147,7 @@ def adafactor_decay_rate_pow(exponent: float, step_counter: JTensor) -> JTensor:
"""
step = step_counter
exponent = jnp.array(exponent, dtype=jnp.float32)
- return 1. - jnp.power((step + 1.), -exponent)
+ return 1.0 - jnp.power((step + 1.0), -exponent)
def reduce_mean(array: JTensor) -> JTensor:
@@ -187,6 +189,7 @@ def reduce_rms(array: JTensor) -> JTensor:
@dataclasses.dataclass(frozen=True)
class _ShardedAdafactorUpdateResult:
"""Structure containing per-variable info for Adafactor."""
+
update: Optional[Any]
m: Optional[Any]
m_scale: Optional[Any]
@@ -197,6 +200,7 @@ class _ShardedAdafactorUpdateResult:
class ShardedAdafactorState(NamedTuple):
"""Overall state of the ShardedAdafactor optimizer."""
+
count: JTensor
m: Optional[NestedJTensor]
m_scale: Optional[NestedJTensor]
@@ -208,27 +212,29 @@ class ShardedAdafactorState(NamedTuple):
class _ShardedAdafactorHelper:
"""Helper class to implement optax-based sharded Adafactor."""
- def __init__(self,
- learning_rate: optax.Schedule,
- weight_decay: Optional[float],
- layerwise_adaptation: bool,
- decay_method: str,
- decay_adam: float,
- decay_pow: float,
- beta1: float,
- clip_threshold: Optional[float],
- factored: bool,
- epsilon1_grad_sq_reg: float,
- quantized_dtype: jnp.dtype,
- respect_skip_lp_regularization: bool,
- exclude_from_layerwise_adaptation: Optional[List[str]],
- per_var_learning_summary: bool,
- sort_factored_second_moment_dims: bool,
- min_dim_size_to_factor: int,
- multiply_by_parameter_scale: bool,
- epsilon2_param_scale_reg: float,
- maybe_inf_to_nan: bool,
- nesterov: bool) -> None:
+ def __init__(
+ self,
+ learning_rate: optax.Schedule,
+ weight_decay: Optional[float],
+ layerwise_adaptation: bool,
+ decay_method: str,
+ decay_adam: float,
+ decay_pow: float,
+ beta1: float,
+ clip_threshold: Optional[float],
+ factored: bool,
+ epsilon1_grad_sq_reg: float,
+ quantized_dtype: jnp.dtype,
+ respect_skip_lp_regularization: bool,
+ exclude_from_layerwise_adaptation: Optional[List[str]],
+ per_var_learning_summary: bool,
+ sort_factored_second_moment_dims: bool,
+ min_dim_size_to_factor: int,
+ multiply_by_parameter_scale: bool,
+ epsilon2_param_scale_reg: float,
+ maybe_inf_to_nan: bool,
+ nesterov: bool,
+ ) -> None:
"""Constructor. See ShardedAdafactor() below."""
self._learning_rate = learning_rate
@@ -315,12 +321,13 @@ def should_store_momentum_in_qint(self, shape):
def to_state(self, count, result_tree):
"""Maps from a tree of (factored) values to separate trees of values."""
return ShardedAdafactorState(
- count=count,
- m=jax.tree.map(lambda o: o.m, result_tree),
- m_scale=jax.tree.map(lambda o: o.m_scale, result_tree),
- vr=jax.tree.map(lambda o: o.vr, result_tree),
- vc=jax.tree.map(lambda o: o.vc, result_tree),
- v=jax.tree.map(lambda o: o.v, result_tree))
+ count=count,
+ m=jax.tree.map(lambda o: o.m, result_tree),
+ m_scale=jax.tree.map(lambda o: o.m_scale, result_tree),
+ vr=jax.tree.map(lambda o: o.vr, result_tree),
+ vc=jax.tree.map(lambda o: o.vc, result_tree),
+ v=jax.tree.map(lambda o: o.v, result_tree),
+ )
def init(self, param):
"""Initializes the optimizer state for a given param."""
@@ -353,12 +360,13 @@ def init(self, param):
else:
output_v = jnp.zeros(shape, dtype=jnp.float32)
return _ShardedAdafactorUpdateResult(
- update=output_update,
- m=output_m,
- m_scale=output_m_scale,
- vr=output_vr,
- vc=output_vc,
- v=output_v)
+ update=output_update,
+ m=output_m,
+ m_scale=output_m_scale,
+ vr=output_vr,
+ vc=output_vc,
+ v=output_v,
+ )
def inf_to_nan(self, array):
"""Converting Infinity values to the more sticky NaN."""
@@ -386,16 +394,9 @@ def parameter_scale(self, var):
"""
return jnp.maximum(reduce_rms(var), jnp.asarray(self._epsilon2, var.dtype))
- def compute_var_and_slot_update(self,
- count,
- grad,
- m,
- m_scale,
- vr,
- vc,
- v,
- param,
- var_name=None):
+ def compute_var_and_slot_update(
+ self, count, grad, m, m_scale, vr, vc, v, param, var_name=None
+ ):
"""Computes the var and optimizer slots updates for a single variable."""
# We can probably skip this step
grad = grad.astype(jnp.float32)
@@ -434,7 +435,7 @@ def compute_var_and_slot_update(self,
update_scale += grad_squared_mean * 1e-30
# END HACK
- mixing_rate = 1. - decay_rate
+ mixing_rate = 1.0 - decay_rate
shape = param.shape
output_m = jnp.zeros((1,))
@@ -449,18 +450,23 @@ def compute_var_and_slot_update(self,
# reduce_mean().
vr_axis, vc_axis = factored_second_moment_dims
grad_squared_row_mean = self.inf_to_nan(
- jnp.mean(grad_squared, axis=vr_axis))
+ jnp.mean(grad_squared, axis=vr_axis)
+ )
grad_squared_col_mean = self.inf_to_nan(
- jnp.mean(grad_squared, axis=vc_axis))
+ jnp.mean(grad_squared, axis=vc_axis)
+ )
new_vr = decay_rate * vr + mixing_rate * grad_squared_row_mean
new_vc = decay_rate * vc + mixing_rate * grad_squared_col_mean
output_vr = new_vr
output_vc = new_vc
long_term_mean = jnp.mean(new_vr, axis=-1, keepdims=True)
- r_factor = 1. / jnp.sqrt(new_vr / long_term_mean)
- c_factor = 1. / jnp.sqrt(new_vc)
- x = grad * jnp.expand_dims(r_factor, vr_axis) * jnp.expand_dims(
- c_factor, vc_axis)
+ r_factor = 1.0 / jnp.sqrt(new_vr / long_term_mean)
+ c_factor = 1.0 / jnp.sqrt(new_vc)
+ x = (
+ grad
+ * jnp.expand_dims(r_factor, vr_axis)
+ * jnp.expand_dims(c_factor, vc_axis)
+ )
else:
# v with sharding annotation.
new_v = decay_rate * v + mixing_rate * grad_squared
@@ -468,7 +474,7 @@ def compute_var_and_slot_update(self,
x = grad / jnp.sqrt(new_v)
if self._clip_threshold is not None:
- clipping_denom = jnp.maximum(1., reduce_rms(x) / self._clip_threshold)
+ clipping_denom = jnp.maximum(1.0, reduce_rms(x) / self._clip_threshold)
clipping_denom = self.inf_to_nan(clipping_denom)
x /= clipping_denom
@@ -481,7 +487,7 @@ def compute_var_and_slot_update(self,
m = to_float(m, m_scale)
if self._nesterov:
subtrahend_original = subtrahend
- subtrahend = self._beta1 * m + (1. - self._beta1) * subtrahend
+ subtrahend = self._beta1 * m + (1.0 - self._beta1) * subtrahend
subtrahend = self.inf_to_nan(subtrahend)
if self._quantized_dtype == jnp.bfloat16:
new_m = subtrahend.astype(jnp.bfloat16)
@@ -496,8 +502,8 @@ def compute_var_and_slot_update(self,
if self._nesterov:
subtrahend = (
- self._beta1 * subtrahend +
- (1.0 - self._beta1) * subtrahend_original)
+ self._beta1 * subtrahend + (1.0 - self._beta1) * subtrahend_original
+ )
if self._weight_decay is not None:
# Apply decoupled weight decay to be consistent with AdamW.
@@ -527,43 +533,45 @@ def compute_var_and_slot_update(self,
g_norm = reduce_rms(subtrahend / update_scale) + self._epsilon1
ratio = w_norm / g_norm
ratio = jnp.where(
- jnp.greater(w_norm, 0),
- jnp.where(jnp.greater(g_norm, 0), (w_norm / g_norm), 1.0),
- 1.0)
+ jnp.greater(w_norm, 0),
+ jnp.where(jnp.greater(g_norm, 0), (w_norm / g_norm), 1.0),
+ 1.0,
+ )
subtrahend *= ratio
return _ShardedAdafactorUpdateResult(
- update=-subtrahend,
- m=output_m,
- m_scale=output_m_scale,
- vr=output_vr,
- vc=output_vc,
- v=output_v)
+ update=-subtrahend,
+ m=output_m,
+ m_scale=output_m_scale,
+ vr=output_vr,
+ vc=output_vc,
+ v=output_v,
+ )
def sharded_adafactor(
- learning_rate: optax.Schedule,
- weight_decay: Optional[Union[float, Dict[str, float]]] = None,
- layerwise_adaptation: bool = False,
- decay_method: str = 'adam',
- decay_adam: float = 0.99,
- decay_pow: float = 0.,
- beta1: float = 0.9,
- clip_threshold: Optional[float] = 1.,
- factored: bool = True,
- epsilon1_grad_sq_reg: float = 1e-30,
- quantized_dtype: jnp.dtype = jnp.int8,
- respect_skip_lp_regularization: bool = False,
- exclude_from_layerwise_adaptation: Optional[List[str]] = None,
- per_var_learning_summary: bool = False,
- sort_factored_second_moment_dims: bool = False,
- # min_dim_size_to_factor is only used when
- # sort_factored_second_moment_dims=True.
- min_dim_size_to_factor: int = 128,
- multiply_by_parameter_scale: bool = False,
- epsilon2_param_scale_reg: float = 1e-3,
- maybe_inf_to_nan: bool = True,
- nesterov: bool = False,
+ learning_rate: optax.Schedule,
+ weight_decay: Optional[Union[float, Dict[str, float]]] = None,
+ layerwise_adaptation: bool = False,
+ decay_method: str = 'adam',
+ decay_adam: float = 0.99,
+ decay_pow: float = 0.0,
+ beta1: float = 0.9,
+ clip_threshold: Optional[float] = 1.0,
+ factored: bool = True,
+ epsilon1_grad_sq_reg: float = 1e-30,
+ quantized_dtype: jnp.dtype = jnp.int8,
+ respect_skip_lp_regularization: bool = False,
+ exclude_from_layerwise_adaptation: Optional[List[str]] = None,
+ per_var_learning_summary: bool = False,
+ sort_factored_second_moment_dims: bool = False,
+ # min_dim_size_to_factor is only used when
+ # sort_factored_second_moment_dims=True.
+ min_dim_size_to_factor: int = 128,
+ multiply_by_parameter_scale: bool = False,
+ epsilon2_param_scale_reg: float = 1e-3,
+ maybe_inf_to_nan: bool = True,
+ nesterov: bool = False,
) -> optax.GradientTransformation:
"""AdaFactor optimizer that supports SPMD sharding.
@@ -638,53 +646,60 @@ def sharded_adafactor(
assert decay_pow >= 0
assert learning_rate is not None
assert decay_method == 'adam' or decay_method == 'pow', (
- f'decay_method: {decay_method} not supported. Supported methods are '
- '"pow", or "adam".')
+ f'decay_method: {decay_method} not supported. Supported methods are '
+ '"pow", or "adam".'
+ )
sharded_adafactor_helper = _ShardedAdafactorHelper(
- learning_rate=learning_rate,
- weight_decay=weight_decay,
- layerwise_adaptation=layerwise_adaptation,
- decay_method=decay_method,
- decay_adam=decay_adam,
- decay_pow=decay_pow,
- beta1=beta1,
- clip_threshold=clip_threshold,
- factored=factored,
- epsilon1_grad_sq_reg=epsilon1_grad_sq_reg,
- quantized_dtype=quantized_dtype,
- respect_skip_lp_regularization=respect_skip_lp_regularization,
- exclude_from_layerwise_adaptation=exclude_from_layerwise_adaptation,
- per_var_learning_summary=per_var_learning_summary,
- sort_factored_second_moment_dims=sort_factored_second_moment_dims,
- min_dim_size_to_factor=min_dim_size_to_factor,
- multiply_by_parameter_scale=multiply_by_parameter_scale,
- epsilon2_param_scale_reg=epsilon2_param_scale_reg,
- maybe_inf_to_nan=maybe_inf_to_nan,
- nesterov=nesterov)
+ learning_rate=learning_rate,
+ weight_decay=weight_decay,
+ layerwise_adaptation=layerwise_adaptation,
+ decay_method=decay_method,
+ decay_adam=decay_adam,
+ decay_pow=decay_pow,
+ beta1=beta1,
+ clip_threshold=clip_threshold,
+ factored=factored,
+ epsilon1_grad_sq_reg=epsilon1_grad_sq_reg,
+ quantized_dtype=quantized_dtype,
+ respect_skip_lp_regularization=respect_skip_lp_regularization,
+ exclude_from_layerwise_adaptation=exclude_from_layerwise_adaptation,
+ per_var_learning_summary=per_var_learning_summary,
+ sort_factored_second_moment_dims=sort_factored_second_moment_dims,
+ min_dim_size_to_factor=min_dim_size_to_factor,
+ multiply_by_parameter_scale=multiply_by_parameter_scale,
+ epsilon2_param_scale_reg=epsilon2_param_scale_reg,
+ maybe_inf_to_nan=maybe_inf_to_nan,
+ nesterov=nesterov,
+ )
def init_fn(params):
"""Initializes the optimizer's state."""
return sharded_adafactor_helper.to_state(
- jnp.zeros([], jnp.int32),
- jax.tree.map(sharded_adafactor_helper.init, params))
+ jnp.zeros([], jnp.int32),
+ jax.tree.map(sharded_adafactor_helper.init, params),
+ )
def update_fn(updates, state, params=None):
if params is None:
raise ValueError(
- 'You are using a transformation that requires the current value of '
- 'parameters, but you are not passing `params` when calling `update`.')
+ 'You are using a transformation that requires the current value of '
+ 'parameters, but you are not passing `params` when calling `update`.'
+ )
compute_var_and_slot_update_fn = functools.partial(
- sharded_adafactor_helper.compute_var_and_slot_update, state.count)
- output = jax.tree.map(compute_var_and_slot_update_fn,
- updates,
- state.m,
- state.m_scale,
- state.vr,
- state.vc,
- state.v,
- params)
+ sharded_adafactor_helper.compute_var_and_slot_update, state.count
+ )
+ output = jax.tree.map(
+ compute_var_and_slot_update_fn,
+ updates,
+ state.m,
+ state.m_scale,
+ state.vr,
+ state.vc,
+ state.v,
+ params,
+ )
updates = jax.tree.map(lambda o: o.update, output)
count_plus_one = state.count + jnp.array(1, jnp.int32)
updated_states = sharded_adafactor_helper.to_state(count_plus_one, output)
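
The factored branch of ``compute_var_and_slot_update`` keeps only row and column means of the squared gradient and rescales the gradient by the factors derived from them. A single-step toy sketch that mirrors that structure, with the running averages (``decay_rate``/``mixing_rate``) left out for brevity::

    import jax.numpy as jnp

    grad = jnp.array([[0.1, -0.2, 0.3],
                      [0.4, 0.0, -0.1]])
    grad_squared = grad ** 2 + 1e-30             # epsilon1_grad_sq_reg
    vr = jnp.mean(grad_squared, axis=1)          # row statistics, shape (2,)
    vc = jnp.mean(grad_squared, axis=0)          # column statistics, shape (3,)
    r_factor = 1.0 / jnp.sqrt(vr / jnp.mean(vr))
    c_factor = 1.0 / jnp.sqrt(vc)
    x = grad * r_factor[:, None] * c_factor[None, :]  # preconditioned update
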
diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py
index 1833ab8af..898de35eb 100644
--- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py
+++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py
@@ -10,17 +10,20 @@
import optax
from algoperf import spec
-from reference_algorithms.paper_baselines.adafactor.jax.sharded_adafactor import \
- sharded_adafactor
+from reference_algorithms.paper_baselines.adafactor.jax.sharded_adafactor import (
+ sharded_adafactor,
+)
_GRAD_CLIP_EPS = 1e-6
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates an Adafactor optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -30,99 +33,113 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+ init_value=hyperparameters.learning_rate, decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters)
opt_init_fn, opt_update_fn = sharded_adafactor(
- learning_rate=lr_schedule_fn,
- beta1=1.0 - hyperparameters.one_minus_beta1,
- weight_decay=hyperparameters.weight_decay)
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ learning_rate=lr_schedule_fn,
+ beta1=1.0 - hyperparameters.one_minus_beta1,
+ weight_decay=hyperparameters.weight_decay,
+ )
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -139,37 +156,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -205,14 +228,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py
index 7aa457a25..dd831566f 100644
--- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py
@@ -16,36 +16,40 @@
USE_PYTORCH_DDP = pytorch_setup()[0]
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates an Adafactor optimizer and a learning rate schedule."""
del model_state
del rng
# Create optimizer.
optimizer_state = {
- 'optimizer':
- Adafactor(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- beta1=1 - hyperparameters.one_minus_beta1,
- weight_decay=hyperparameters.weight_decay),
+ 'optimizer': Adafactor(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ beta1=1 - hyperparameters.one_minus_beta1,
+ weight_decay=hyperparameters.weight_decay,
+ ),
}
optimizer = optimizer_state['optimizer']
warmup = LinearLR(
- optimizer,
- start_factor=1e-10,
- end_factor=1.,
- total_iters=hyperparameters.warmup_steps)
+ optimizer,
+ start_factor=1e-10,
+ end_factor=1.0,
+ total_iters=hyperparameters.warmup_steps,
+ )
cosine_steps = max(workload.step_hint - hyperparameters.warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
optimizer_state['scheduler'] = SequentialLR(
- optimizer,
- schedulers=[warmup, cosine_decay],
- milestones=[hyperparameters.warmup_steps])
+ optimizer,
+ schedulers=[warmup, cosine_decay],
+ milestones=[hyperparameters.warmup_steps],
+ )
return optimizer_state
@@ -54,56 +58,56 @@ class Adafactor(torch.optim.Optimizer):
src/transformers/optimization.py#L386"""
def __init__(
- self,
- params,
- lr=None,
- beta1=0.9,
- decay_adam=0.99,
- weight_decay=0.0,
+ self,
+ params,
+ lr=None,
+ beta1=0.9,
+ decay_adam=0.99,
+ weight_decay=0.0,
):
defaults = dict(
- lr=lr,
- beta1=beta1,
- decay_adam=decay_adam,
- weight_decay=weight_decay,
- decay_pow=0.0,
- layerwise_adaptation=False,
- decay_method='adam',
- clip_threshold=1.0,
- factored=True,
- epsilon1_grad_sq_reg=1e-30,
- respect_skip_lp_regularization=False,
- exclude_from_layerwise_adaptation=None,
- per_var_learning_summary=False,
- sort_factored_second_moment_dims=False,
- # Unused because sort_factored_second_moment_dims=False.
- min_dim_size_to_factor=128,
- multiply_by_parameter_scale=False,
- # Unused because multiply_by_parameter_scale=False.
- epsilon2_param_scale_reg=1e-3,
- maybe_inf_to_nan=True,
+ lr=lr,
+ beta1=beta1,
+ decay_adam=decay_adam,
+ weight_decay=weight_decay,
+ decay_pow=0.0,
+ layerwise_adaptation=False,
+ decay_method='adam',
+ clip_threshold=1.0,
+ factored=True,
+ epsilon1_grad_sq_reg=1e-30,
+ respect_skip_lp_regularization=False,
+ exclude_from_layerwise_adaptation=None,
+ per_var_learning_summary=False,
+ sort_factored_second_moment_dims=False,
+ # Unused because sort_factored_second_moment_dims=False.
+ min_dim_size_to_factor=128,
+ multiply_by_parameter_scale=False,
+ # Unused because multiply_by_parameter_scale=False.
+ epsilon2_param_scale_reg=1e-3,
+ maybe_inf_to_nan=True,
)
super().__init__(params, defaults)
def inf_to_nan(self, group, x):
- if group["maybe_inf_to_nan"]:
+ if group['maybe_inf_to_nan']:
x = torch.nan_to_num(x, nan=torch.nan, posinf=torch.nan, neginf=torch.nan)
return x
def step(self, closure=None):
"""
- Performs a single optimization step
- Arguments:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
- """
+ Performs a single optimization step
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
inf_to_nan = partial(self.inf_to_nan, group)
- for p in group["params"]:
+ for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
@@ -111,7 +115,7 @@ def step(self, closure=None):
if grad.dtype in {torch.float16, torch.bfloat16}:
grad = grad.float()
if grad.is_sparse:
- raise RuntimeError("Adafactor does not support sparse gradients.")
+ raise RuntimeError('Adafactor does not support sparse gradients.')
state = self.state[p]
grad_shape = grad.shape
@@ -120,51 +124,54 @@ def step(self, closure=None):
# State Initialization
if len(state) == 0:
- state["step"] = 0
- state["exp_avg"] = torch.zeros_like(grad)
+ state['step'] = 0
+ state['exp_avg'] = torch.zeros_like(grad)
if factored:
- state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
- state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] +
- grad_shape[-1:]).to(grad)
+ state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
+ state['exp_avg_sq_col'] = torch.zeros(
+ grad_shape[:-2] + grad_shape[-1:]
+ ).to(grad)
else:
- state["exp_avg_sq"] = torch.zeros_like(grad)
+ state['exp_avg_sq'] = torch.zeros_like(grad)
else:
- state["exp_avg"] = state["exp_avg"].to(grad)
+ state['exp_avg'] = state['exp_avg'].to(grad)
if factored:
- state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
- state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
+ state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
+ state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
else:
- state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
+ state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)
p_data_fp32 = p.data
if p.data.dtype in {torch.float16, torch.bfloat16}:
p_data_fp32 = p_data_fp32.float()
- state["step"] += 1
- lr = group["lr"]
- beta1 = group["beta1"]
- beta2 = group["decay_adam"]
+ state['step'] += 1
+ lr = group['lr']
+ beta1 = group['beta1']
+ beta2 = group['decay_adam']
- t = state["step"]
- beta2t = beta2 * (1. - beta2**(t - 1.)) / (1. - beta2**t)
+ t = state['step']
+ beta2t = beta2 * (1.0 - beta2 ** (t - 1.0)) / (1.0 - beta2**t)
- exp_avg_sq_update = (grad**2) + group["epsilon1_grad_sq_reg"]
+ exp_avg_sq_update = (grad**2) + group['epsilon1_grad_sq_reg']
if factored:
- exp_avg_sq_row = state["exp_avg_sq_row"]
- exp_avg_sq_col = state["exp_avg_sq_col"]
+ exp_avg_sq_row = state['exp_avg_sq_row']
+ exp_avg_sq_col = state['exp_avg_sq_col']
exp_avg_sq_row.mul_(beta2t).add_(
- exp_avg_sq_update.mean(dim=-1), alpha=1.0 - beta2t)
+ exp_avg_sq_update.mean(dim=-1), alpha=1.0 - beta2t
+ )
exp_avg_sq_col.mul_(beta2t).add_(
- exp_avg_sq_update.mean(dim=-2), alpha=1.0 - beta2t)
+ exp_avg_sq_update.mean(dim=-2), alpha=1.0 - beta2t
+ )
r_factor = inf_to_nan(
- exp_avg_sq_row /
- exp_avg_sq_row.mean(dim=-1, keepdim=True)).unsqueeze(-1)
+ exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)
+ ).unsqueeze(-1)
c_factor = inf_to_nan(exp_avg_sq_col).unsqueeze(-2)
denom = r_factor * c_factor
else:
- exp_avg_sq = state["exp_avg_sq"]
+ exp_avg_sq = state['exp_avg_sq']
exp_avg_sq.mul_(beta2t).add_(exp_avg_sq_update, alpha=1.0 - beta2t)
denom = exp_avg_sq
@@ -172,15 +179,16 @@ def step(self, closure=None):
denom = denom.sqrt()
update = grad / denom
# Clip the update based on RMS.
- clipping_denom = inf_to_nan(torch.square(update).mean().sqrt() \
- /group["clip_threshold"]).clamp(min=1.0)
+ clipping_denom = inf_to_nan(
+ torch.square(update).mean().sqrt() / group['clip_threshold']
+ ).clamp(min=1.0)
update = update / clipping_denom * lr
# Momentum
- exp_avg = state["exp_avg"]
+ exp_avg = state['exp_avg']
exp_avg.mul_(beta1).add_(update, alpha=1 - beta1)
- if group["weight_decay"] != 0:
- p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * lr)
+ if group['weight_decay'] != 0:
+ p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * lr)
p_data_fp32.add_(-exp_avg)
@@ -191,18 +199,19 @@ def step(self, closure=None):
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -214,22 +223,26 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -243,12 +256,14 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -256,28 +271,34 @@ def update_params(
if global_step <= 100 or global_step % 500 == 0:
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -314,14 +335,15 @@ def get_batch_size(workload_name):
def data_selection(
- workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]:
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py
index dde41fa6d..52c8d5ee2 100644
--- a/reference_algorithms/paper_baselines/adamw/jax/submission.py
+++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py
@@ -14,11 +14,13 @@
_GRAD_CLIP_EPS = 1e-6
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates an AdamW optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -28,101 +30,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+ init_value=hyperparameters.learning_rate, decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters)
opt_init_fn, opt_update_fn = optax.adamw(
- learning_rate=lr_schedule_fn,
- b1=1.0 - hyperparameters.one_minus_beta1,
- b2=hyperparameters.beta2,
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay)
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ learning_rate=lr_schedule_fn,
+ b1=1.0 - hyperparameters.one_minus_beta1,
+ b2=hyperparameters.beta2,
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ )
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -139,37 +155,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -209,14 +231,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py
index 21d9b6b57..faefcd254 100644
--- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py
@@ -15,55 +15,60 @@
USE_PYTORCH_DDP = pytorch_setup()[0]
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates an AdamW optimizer and a learning rate schedule."""
del model_state
del rng
optimizer_state = {
- 'optimizer':
- torch.optim.AdamW(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(1.0 - hyperparameters.one_minus_beta1,
- hyperparameters.beta2),
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay,
- fused=False),
+ 'optimizer': torch.optim.AdamW(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2),
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ fused=False,
+ ),
}
def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
warmup_steps = hyperparameters.warmup_factor * step_hint
warmup = LinearLR(
- optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
+ optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
- optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
+ optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
+ )
optimizer_state['scheduler'] = pytorch_cosine_warmup(
- workload.step_hint, hyperparameters, optimizer_state['optimizer'])
+ workload.step_hint, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -75,22 +80,26 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -104,7 +113,8 @@ def update_params(
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -113,31 +123,38 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -177,14 +194,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py
index 70e305514..eedbbfc37 100644
--- a/reference_algorithms/paper_baselines/lamb/jax/submission.py
+++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py
@@ -21,11 +21,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True):
return optax.scale(m * learning_rate)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a LAMB optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -35,61 +37,70 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+ init_value=hyperparameters.learning_rate, decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters)
opt_init_fn, opt_update_fn = optax.lamb(
- learning_rate=lr_schedule_fn,
- b1=1 - hyperparameters.one_minus_beta1,
- b2=hyperparameters.beta2,
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay)
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ learning_rate=lr_schedule_fn,
+ b1=1 - hyperparameters.one_minus_beta1,
+ b2=hyperparameters.beta2,
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ )
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
@@ -97,40 +108,45 @@ def _loss_fn(params):
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -147,37 +163,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -213,14 +235,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py
index c1c6cec0a..5cda59d6f 100644
--- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py
@@ -15,13 +15,9 @@
# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py
class LAMB(torch.optim.Optimizer):
-
- def __init__(self,
- params,
- lr=1e-3,
- betas=(0.9, 0.999),
- eps=1e-8,
- weight_decay=0.0):
+ def __init__(
+ self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0
+ ):
if not 0.0 <= lr:
raise ValueError(f'Invalid learning rate: {lr}')
if not 0.0 <= eps:
@@ -39,7 +35,8 @@ def __setstate__(self, state):
super().__setstate__(state)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
- state_values[0]['step'])
+ state_values[0]['step']
+ )
if not step_is_tensor:
for s in state_values:
s['step'] = torch.tensor(float(s['step']))
@@ -74,48 +71,53 @@ def step(self, closure=None):
# State initialization
if len(state) == 0:
- state['step'] = torch.tensor(0.)
+ state['step'] = torch.tensor(0.0)
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
state_steps.append(state['step'])
lamb(
- params_with_grad,
- grads,
- exp_avgs,
- exp_avg_sqs,
- state_steps,
- beta1=beta1,
- beta2=beta2,
- lr=group['lr'],
- weight_decay=group['weight_decay'],
- eps=group['eps'])
+ params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ state_steps,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ )
return loss
-def lamb(params: List[Tensor],
- grads: List[Tensor],
- exp_avgs: List[Tensor],
- exp_avg_sqs: List[Tensor],
- state_steps: List[Tensor],
- beta1: float,
- beta2: float,
- lr: float,
- weight_decay: float,
- eps: float):
-
+def lamb(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ state_steps: List[Tensor],
+ beta1: float,
+ beta2: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+):
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
- 'API has changed, `state_steps` argument must contain a list of' +
- ' singleton tensors')
+ 'API has changed, `state_steps` argument must contain a list of'
+ + ' singleton tensors'
+ )
for i, param in enumerate(params):
grad = grads[i]
@@ -147,61 +149,67 @@ def lamb(params: List[Tensor],
update_norm = torch.linalg.norm(update)
# Set trust_ratio to 1 in case where parameters would never be updated.
- if param_norm == 0. or update_norm == 0.:
- trust_ratio = 1.
+ if param_norm == 0.0 or update_norm == 0.0:
+ trust_ratio = 1.0
else:
trust_ratio = param_norm / update_norm
param.add_(update, alpha=-lr * trust_ratio)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a LAMB optimizer and a learning rate schedule."""
del model_state
del rng
optimizer_state = {
- 'optimizer':
- LAMB(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(hyperparameters.beta1, hyperparameters.beta2),
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay)
+ 'optimizer': LAMB(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(hyperparameters.beta1, hyperparameters.beta2),
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ )
}
def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup = LinearLR(
- optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
+ optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
- optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
+ optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
+ )
optimizer_state['scheduler'] = pytorch_cosine_warmup(
- workload.step_hint, hyperparameters, optimizer_state['optimizer'])
+ workload.step_hint, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -213,31 +221,36 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
loss, _ = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
loss.backward()
if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -246,31 +259,38 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -306,14 +326,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py
index cbb6d6dcd..3540f9415 100644
--- a/reference_algorithms/paper_baselines/momentum/jax/submission.py
+++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py
@@ -14,11 +14,13 @@
_GRAD_CLIP_EPS = 1e-6
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -28,34 +30,39 @@ def init_optimizer_state(workload: spec.Workload,
lr_schedule_fn = create_lr_schedule_fn(workload.step_hint, hyperparameters)
# Create optimizer.
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
opt_init_fn, opt_update_fn = sgd(
- learning_rate=lr_schedule_fn,
- weight_decay=hyperparameters.weight_decay,
- momentum=1.0 - hyperparameters.one_minus_beta1,
- nesterov=False)
+ learning_rate=lr_schedule_fn,
+ weight_decay=hyperparameters.weight_decay,
+ momentum=1.0 - hyperparameters.one_minus_beta1,
+ nesterov=False,
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
def create_lr_schedule_fn(
- step_hint: int,
- hyperparameters: spec.Hyperparameters) -> Callable[[int], float]:
+ step_hint: int, hyperparameters: spec.Hyperparameters
+) -> Callable[[int], float]:
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
decay_steps = step_hint - warmup_steps
polynomial_schedule_fn = optax.polynomial_schedule(
- init_value=hyperparameters.learning_rate,
- end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
- power=1,
- transition_steps=int(decay_steps * hyperparameters.decay_steps_factor))
+ init_value=hyperparameters.learning_rate,
+ end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
+ power=1,
+ transition_steps=int(decay_steps * hyperparameters.decay_steps_factor),
+ )
lr_schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps]
+ )
return lr_schedule_fn
@@ -82,81 +89,92 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False):
{SGD, Momentum, Nesterov} update.
"""
return optax.chain(
- optax.add_decayed_weights(weight_decay),
- optax.sgd(
- learning_rate=learning_rate, momentum=momentum, nesterov=nesterov))
+ optax.add_decayed_weights(weight_decay),
+ optax.sgd(
+ learning_rate=learning_rate, momentum=momentum, nesterov=nesterov
+ ),
+ )
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -173,37 +191,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -243,14 +267,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py
index c3760d20e..bad750857 100644
--- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py
@@ -14,24 +14,26 @@
USE_PYTORCH_DDP = pytorch_setup()[0]
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule."""
del model_state
del rng
# Create optimizer.
optimizer_state = {
- 'optimizer':
- torch.optim.SGD(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- momentum=1.0 - hyperparameters.one_minus_beta1,
- weight_decay=hyperparameters.weight_decay,
- nesterov=False),
+ 'optimizer': torch.optim.SGD(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ momentum=1.0 - hyperparameters.one_minus_beta1,
+ weight_decay=hyperparameters.weight_decay,
+ nesterov=False,
+ ),
}
# Create learning rate schedule.
@@ -43,43 +45,48 @@ def _lr_lambda(step: int) -> float:
return lr_schedule_fn(step).item() / hyperparameters.learning_rate
optimizer_state['scheduler'] = LambdaLR(
- optimizer_state['optimizer'], lr_lambda=_lr_lambda)
+ optimizer_state['optimizer'], lr_lambda=_lr_lambda
+ )
return optimizer_state
def create_lr_schedule_fn(
- step_hint: int,
- hyperparameters: spec.Hyperparameters) -> Callable[[int], float]:
+ step_hint: int, hyperparameters: spec.Hyperparameters
+) -> Callable[[int], float]:
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
decay_steps = step_hint - warmup_steps
polynomial_schedule_fn = optax.polynomial_schedule(
- init_value=hyperparameters.learning_rate,
- end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
- power=1,
- transition_steps=int(decay_steps * hyperparameters.decay_steps_factor))
+ init_value=hyperparameters.learning_rate,
+ end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
+ power=1,
+ transition_steps=int(decay_steps * hyperparameters.decay_steps_factor),
+ )
lr_schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps]
+ )
return lr_schedule_fn
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -91,26 +98,30 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -123,7 +134,8 @@ def update_params(
if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -132,31 +144,38 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -196,14 +215,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py
index c451a18ac..aa1a08f69 100644
--- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py
+++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py
@@ -4,15 +4,17 @@
# isort: off
# We have to turn off isort here to resolve a conflict between isort and yapf.
-from typing import (Any,
- Callable,
- Dict,
- Iterator,
- List,
- NamedTuple,
- Optional,
- Tuple,
- Union)
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ NamedTuple,
+ Optional,
+ Tuple,
+ Union,
+)
# isort: on
import chex
@@ -30,15 +32,14 @@
# Forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py
def nadamw(
- learning_rate: Union[float, optax.Schedule],
- b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- weight_decay: float = 0.0,
- weight_decay_mask: Optional[Union[Any, Callable[[optax.Params],
- Any]]] = None,
+ learning_rate: Union[float, optax.Schedule],
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ weight_decay: float = 0.0,
+ weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None,
) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
@@ -73,19 +74,22 @@ def nadamw(
An (init_fn, update_fn) tuple.
"""
return optax.chain(
- scale_by_nadam(b1, b2, eps, eps_root, debias),
- optax.add_decayed_weights(weight_decay, weight_decay_mask),
- scale_by_learning_rate(learning_rate))
+ scale_by_nadam(b1, b2, eps, eps_root, debias),
+ optax.add_decayed_weights(weight_decay, weight_decay_mask),
+ scale_by_learning_rate(learning_rate),
+ )
# All functions below are forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py
-def scale_by_nadam(b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- power: float = 0.5) -> optax.GradientTransformation:
+def scale_by_nadam(
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ power: float = 0.5,
+) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
References:
@@ -124,7 +128,8 @@ def update_fn(updates, state, params=None):
mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
nu_hat = nu if not debias else _bias_correction(nu, b2, count)
updates = jax.tree.map(
- lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
+ lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat
+ )
return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
return optax.GradientTransformation(init_fn, update_fn)
@@ -132,6 +137,7 @@ def update_fn(updates, state, params=None):
class ScaleByAdamState(NamedTuple):
"""State for the NAdam algorithm."""
+
count: chex.Array # shape=(), dtype=jnp.int32.
mu: optax.Updates
nu: optax.Updates
@@ -140,7 +146,8 @@ class ScaleByAdamState(NamedTuple):
def _update_moment(updates, moments, decay, order):
"""Compute the exponential moving average of the `order-th` moment."""
return jax.tree.map(
- lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
+ lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments
+ )
def _bias_correction(moment, decay, count):
@@ -156,11 +163,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True):
return optax.scale(m * learning_rate)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -170,101 +179,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+ init_value=hyperparameters.learning_rate, decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters)
opt_init_fn, opt_update_fn = nadamw(
- learning_rate=lr_schedule_fn,
- b1=1.0 - hyperparameters.one_minus_beta1,
- b2=hyperparameters.beta2,
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay)
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ learning_rate=lr_schedule_fn,
+ b1=1.0 - hyperparameters.one_minus_beta1,
+ b2=hyperparameters.beta2,
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ )
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -281,37 +304,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -351,14 +380,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py
index a2f9fb4c5..ecd299988 100644
--- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py
@@ -21,33 +21,30 @@
class NAdamW(torch.optim.Optimizer):
r"""Implements NAdamW algorithm.
- See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
- the NAdam algorithm (there is also a comment in the code which highlights
- the only difference of NAdamW and AdamW).
- For further details regarding the algorithm we refer to
- `Decoupled Weight Decay Regularization`_.
-
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its square (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-8)
- weight_decay (float, optional): weight decay coefficient (default: 1e-2)
- .. _Decoupled Weight Decay Regularization:
- https://arxiv.org/abs/1711.05101
- .. _On the Convergence of Adam and Beyond:
- https://openreview.net/forum?id=ryQu7f-RZ
+ See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
+ the NAdam algorithm (there is also a comment in the code which highlights
+  the only difference between NAdamW and AdamW).
+ For further details regarding the algorithm we refer to
+ `Decoupled Weight Decay Regularization`_.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+ .. _Decoupled Weight Decay Regularization:
+ https://arxiv.org/abs/1711.05101
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
"""
- def __init__(self,
- params,
- lr=1e-3,
- betas=(0.9, 0.999),
- eps=1e-8,
- weight_decay=1e-2):
+ def __init__(
+ self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2
+ ):
if not 0.0 <= lr:
raise ValueError(f'Invalid learning rate: {lr}')
if not 0.0 <= eps:
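
As a rough illustration of the class docstring above, here is a heavily simplified sketch (not part of the patch) of one NAdamW parameter update in plain NumPy, with bias corrections left out for brevity; the names (nadamw_step, m, v, ...) are made up for this example. The only departure from the AdamW step is the Nesterov-style look-ahead moment used in the numerator of the update.

    import numpy as np

    def nadamw_step(param, grad, m, v, lr, beta1, beta2, eps, weight_decay):
      param = param * (1 - lr * weight_decay)     # decoupled weight decay (AdamW)
      m = beta1 * m + (1 - beta1) * grad          # first moment, as in AdamW
      v = beta2 * v + (1 - beta2) * grad**2       # second moment, as in AdamW
      # The NAdamW-specific change: use a Nesterov-style look-ahead moment
      # beta1*m + (1-beta1)*grad in the update instead of m itself.
      m_lookahead = beta1 * m + (1 - beta1) * grad
      param = param - lr * m_lookahead / (np.sqrt(v) + eps)
      return param, m, v

    p, g = np.ones(3), 0.1 * np.ones(3)
    p, m, v = nadamw_step(p, g, np.zeros(3), np.zeros(3),
                          lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=1e-2)
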
@@ -59,7 +56,10 @@ def __init__(self,
if not 0.0 <= weight_decay:
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
defaults = {
- 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay
+ 'lr': lr,
+ 'betas': betas,
+ 'eps': eps,
+ 'weight_decay': weight_decay,
}
super().__init__(params, defaults)
@@ -67,7 +67,8 @@ def __setstate__(self, state):
super().__setstate__(state)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
- state_values[0]['step'])
+ state_values[0]['step']
+ )
if not step_is_tensor:
for s in state_values:
s['step'] = torch.tensor(float(s['step']))
@@ -76,9 +77,9 @@ def __setstate__(self, state):
def step(self, closure=None):
"""Performs a single optimization step.
- Args:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
"""
self._cuda_graph_capture_health_check()
@@ -107,51 +108,57 @@ def step(self, closure=None):
# State initialization
if len(state) == 0:
- state['step'] = torch.tensor(0.)
+ state['step'] = torch.tensor(0.0)
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
state_steps.append(state['step'])
nadamw(
- params_with_grad,
- grads,
- exp_avgs,
- exp_avg_sqs,
- state_steps,
- beta1=beta1,
- beta2=beta2,
- lr=group['lr'],
- weight_decay=group['weight_decay'],
- eps=group['eps'])
+ params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ state_steps,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ )
return loss
-def nadamw(params: List[Tensor],
- grads: List[Tensor],
- exp_avgs: List[Tensor],
- exp_avg_sqs: List[Tensor],
- state_steps: List[Tensor],
- beta1: float,
- beta2: float,
- lr: float,
- weight_decay: float,
- eps: float) -> None:
+def nadamw(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ state_steps: List[Tensor],
+ beta1: float,
+ beta2: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+) -> None:
r"""Functional API that performs NAdamW algorithm computation.
- See NAdamW class for details.
+ See NAdamW class for details.
"""
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
- 'API has changed, `state_steps` argument must contain a list of' +
- ' singleton tensors')
+ 'API has changed, `state_steps` argument must contain a list of'
+ + ' singleton tensors'
+ )
for i, param in enumerate(params):
grad = grads[i]
@@ -189,54 +196,59 @@ def nadamw(params: List[Tensor],
exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_state
del rng
optimizer_state = {
- 'optimizer':
- NAdamW(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(1.0 - hyperparameters.one_minus_beta1,
- hyperparameters.beta2),
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay),
+ 'optimizer': NAdamW(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2),
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ ),
}
def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup = LinearLR(
- optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
+ optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
- optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
+ optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
+ )
optimizer_state['scheduler'] = pytorch_cosine_warmup(
- workload.step_hint, hyperparameters, optimizer_state['optimizer'])
+ workload.step_hint, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
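
As a rough guide to the schedule that pytorch_cosine_warmup builds above (a LinearLR warmup followed by CosineAnnealingLR under a SequentialLR), here is a small sketch (not part of the patch) of the resulting learning-rate curve written as a plain function of the step; the argument names are illustrative and the warmup is approximated as starting from zero.

    import math

    def warmup_then_cosine(step, peak_lr, warmup_steps, total_steps):
      if step < warmup_steps:
        # Linear warmup from (almost) zero up to the peak learning rate.
        return peak_lr * (step + 1) / warmup_steps
      # Cosine decay from the peak down to zero over the remaining steps.
      cosine_steps = max(total_steps - warmup_steps, 1)
      progress = min((step - warmup_steps) / cosine_steps, 1.0)
      return 0.5 * peak_lr * (1 + math.cos(math.pi * progress))

    lrs = [warmup_then_cosine(s, peak_lr=1e-3, warmup_steps=10, total_steps=100)
           for s in range(100)]
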
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -248,26 +260,30 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -280,7 +296,8 @@ def update_params(
if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -289,31 +306,38 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -353,14 +377,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py
index 0e53aae42..d32026212 100644
--- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py
+++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py
@@ -14,11 +14,13 @@
_GRAD_CLIP_EPS = 1e-6
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -28,34 +30,39 @@ def init_optimizer_state(workload: spec.Workload,
lr_schedule_fn = create_lr_schedule_fn(workload.step_hint, hyperparameters)
# Create optimizer.
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
opt_init_fn, opt_update_fn = sgd(
- learning_rate=lr_schedule_fn,
- weight_decay=hyperparameters.weight_decay,
- momentum=1.0 - hyperparameters.one_minus_beta1,
- nesterov=True)
+ learning_rate=lr_schedule_fn,
+ weight_decay=hyperparameters.weight_decay,
+ momentum=1.0 - hyperparameters.one_minus_beta1,
+ nesterov=True,
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
def create_lr_schedule_fn(
- step_hint: int,
- hyperparameters: spec.Hyperparameters) -> Callable[[int], float]:
+ step_hint: int, hyperparameters: spec.Hyperparameters
+) -> Callable[[int], float]:
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
decay_steps = step_hint - warmup_steps
polynomial_schedule_fn = optax.polynomial_schedule(
- init_value=hyperparameters.learning_rate,
- end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
- power=1,
- transition_steps=int(decay_steps * hyperparameters.decay_steps_factor))
+ init_value=hyperparameters.learning_rate,
+ end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
+ power=1,
+ transition_steps=int(decay_steps * hyperparameters.decay_steps_factor),
+ )
lr_schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps]
+ )
return lr_schedule_fn
@@ -82,81 +89,92 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False):
{SGD, Momentum, Nesterov} update.
"""
return optax.chain(
- optax.add_decayed_weights(weight_decay),
- optax.sgd(
- learning_rate=learning_rate, momentum=momentum, nesterov=nesterov))
+ optax.add_decayed_weights(weight_decay),
+ optax.sgd(
+ learning_rate=learning_rate, momentum=momentum, nesterov=nesterov
+ ),
+ )
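
For context, a short usage sketch (not part of the patch) of the chained transformation that sgd() above returns: optax.add_decayed_weights applies decoupled weight decay and optax.sgd applies the (Nesterov) momentum step. The hyperparameter values and the params/grads pytrees are placeholders.

    import jax.numpy as jnp
    import optax

    tx = optax.chain(
        optax.add_decayed_weights(1e-4),                    # decoupled weight decay
        optax.sgd(learning_rate=0.1, momentum=0.9, nesterov=True),
    )
    params = {'w': jnp.ones((3,))}
    grads = {'w': jnp.full((3,), 0.5)}
    state = tx.init(params)
    updates, state = tx.update(grads, state, params)        # params needed for weight decay
    params = optax.apply_updates(params, updates)
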
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -173,37 +191,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -243,14 +267,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py
index b4432fbff..77361f472 100644
--- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py
@@ -14,24 +14,26 @@
USE_PYTORCH_DDP = pytorch_setup()[0]
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule."""
del model_state
del rng
# Create optimizer.
optimizer_state = {
- 'optimizer':
- torch.optim.SGD(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- momentum=1.0 - hyperparameters.one_minus_beta1,
- weight_decay=hyperparameters.weight_decay,
- nesterov=True),
+ 'optimizer': torch.optim.SGD(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ momentum=1.0 - hyperparameters.one_minus_beta1,
+ weight_decay=hyperparameters.weight_decay,
+ nesterov=True,
+ ),
}
# Create learning rate schedule.
@@ -43,43 +45,48 @@ def _lr_lambda(step: int) -> float:
return lr_schedule_fn(step).item() / hyperparameters.learning_rate
optimizer_state['scheduler'] = LambdaLR(
- optimizer_state['optimizer'], lr_lambda=_lr_lambda)
+ optimizer_state['optimizer'], lr_lambda=_lr_lambda
+ )
return optimizer_state
def create_lr_schedule_fn(
- step_hint: int,
- hyperparameters: spec.Hyperparameters) -> Callable[[int], float]:
+ step_hint: int, hyperparameters: spec.Hyperparameters
+) -> Callable[[int], float]:
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
decay_steps = step_hint - warmup_steps
polynomial_schedule_fn = optax.polynomial_schedule(
- init_value=hyperparameters.learning_rate,
- end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
- power=1,
- transition_steps=int(decay_steps * hyperparameters.decay_steps_factor))
+ init_value=hyperparameters.learning_rate,
+ end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
+ power=1,
+ transition_steps=int(decay_steps * hyperparameters.decay_steps_factor),
+ )
lr_schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, polynomial_schedule_fn], boundaries=[warmup_steps]
+ )
return lr_schedule_fn
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -91,26 +98,30 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -123,7 +134,8 @@ def update_params(
if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -132,31 +144,38 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -196,14 +215,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py
index b76589705..ce6db3ac3 100644
--- a/reference_algorithms/paper_baselines/sam/jax/submission.py
+++ b/reference_algorithms/paper_baselines/sam/jax/submission.py
@@ -23,7 +23,8 @@ def dual_vector(y: jnp.ndarray) -> jnp.ndarray:
y: A pytree of numpy ndarray, vector y in the equation above.
"""
gradient_norm = jnp.sqrt(
- sum(jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y)))
+ sum(jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y))
+ )
normalized_gradient = jax.tree.map(lambda x: x / gradient_norm, y)
return normalized_gradient
@@ -31,11 +32,11 @@ def dual_vector(y: jnp.ndarray) -> jnp.ndarray:
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/
# sharpness_aware_minimization.py
def sharpness_aware_minimization(
- rho: float,
- grad_clip: Optional[float],
- batch_axis_name: str,
- base_opt_init_fn,
- base_opt_update_fn,
+ rho: float,
+ grad_clip: Optional[float],
+ batch_axis_name: str,
+ base_opt_init_fn,
+ base_opt_update_fn,
) -> optax.GradientTransformation:
"""Implementation of Sharpness Aware Minimization (SAM).
Paper: https://arxiv.org/abs/2010.01412
@@ -68,22 +69,28 @@ def update_fn(updates, state, grad_fn_params_tuple):
# with the same 1e-6 epsilon that is used when clipping the gradients.
updates = dual_vector(updates)
noised_params = jax.tree_util.tree_map(
- lambda p, u: p + rho * u, params, updates)
+ lambda p, u: p + rho * u, params, updates
+ )
(_, (n_valid_examples, _)), updates = grad_fn(noised_params)
# Get correct global mean grad.
- (n_valid_examples, updates) = lax.psum((n_valid_examples, updates),
- axis_name=batch_axis_name)
+ (n_valid_examples, updates) = lax.psum(
+ (n_valid_examples, updates), axis_name=batch_axis_name
+ )
updates = jax.tree.map(lambda x: x / n_valid_examples, updates)
if grad_clip:
updates_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates))
+ )
scaled_updates = jax.tree.map(
- lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates)
- updates = jax.lax.cond(updates_norm > grad_clip,
- lambda _: scaled_updates,
- lambda _: updates,
- None)
+ lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates
+ )
+ updates = jax.lax.cond(
+ updates_norm > grad_clip,
+ lambda _: scaled_updates,
+ lambda _: updates,
+ None,
+ )
updates, state = base_opt_update_fn(updates, state, params)
return updates, state
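
To summarize the update_fn above, here is a single-device sketch (not part of the patch) of the two gradient evaluations SAM performs: normalize the gradient at the current parameters, perturb the parameters by rho along that unit direction, and hand the gradient at the perturbed point to the base optimizer. loss_fn, params and rho below are stand-ins, and the psum and clipping details are omitted.

    import jax
    import jax.numpy as jnp

    def sam_direction(loss_fn, params, rho):
      # 1) Gradient at the current parameters, normalized to unit L2 norm.
      grads = jax.grad(loss_fn)(params)
      norm = jnp.sqrt(sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grads)))
      unit = jax.tree.map(lambda g: g / norm, grads)
      # 2) Gradient re-evaluated at the perturbed ("noised") parameters
      #    p + rho * unit; this is what the base optimizer consumes.
      noised = jax.tree.map(lambda p, u: p + rho * u, params, unit)
      return jax.grad(loss_fn)(noised)

    params = {'w': jnp.array([1.0, -2.0])}
    loss_fn = lambda p: jnp.sum(p['w'] ** 2)
    sam_grads = sam_direction(loss_fn, params, rho=0.05)
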
@@ -91,11 +98,13 @@ def update_fn(updates, state, grad_fn_params_tuple):
return optax.GradientTransformation(init_fn, update_fn)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a SAM optimizer (with AdamW base) and a learning rate schedule."""
del model_params
del model_state
@@ -105,111 +114,127 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+ init_value=hyperparameters.learning_rate, decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create base optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters)
opt_init_fn, opt_update_fn = optax.adamw(
- learning_rate=lr_schedule_fn,
- b1=1.0 - hyperparameters.one_minus_beta1,
- b2=hyperparameters.beta2,
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay)
+ learning_rate=lr_schedule_fn,
+ b1=1.0 - hyperparameters.one_minus_beta1,
+ b2=hyperparameters.beta2,
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ )
# Create SAM update fn.
grad_clip = (
- hyperparameters.grad_clip
- if hasattr(hyperparameters, 'grad_clip') else None)
+ hyperparameters.grad_clip if hasattr(hyperparameters, 'grad_clip') else None
+ )
opt_init_fn, opt_update_fn = sharpness_aware_minimization(
- rho=hyperparameters.rho,
- grad_clip=grad_clip,
- batch_axis_name='batch',
- base_opt_init_fn=opt_init_fn,
- base_opt_update_fn=opt_update_fn)
+ rho=hyperparameters.rho,
+ grad_clip=grad_clip,
+ batch_axis_name='batch',
+ base_opt_init_fn=opt_init_fn,
+ base_opt_update_fn=opt_update_fn,
+ )
# Initialize optimizer state.
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params, update_batch_norm=True):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=update_batch_norm)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=update_batch_norm,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
second_grad_fn = jax.value_and_grad(
- functools.partial(_loss_fn, update_batch_norm=False), has_aux=True)
+ functools.partial(_loss_fn, update_batch_norm=False), has_aux=True
+ )
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
updates, new_optimizer_state = opt_update_fn(
- grad, optimizer_state, (second_grad_fn, current_param_container))
+ grad, optimizer_state, (second_grad_fn, current_param_container)
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -226,37 +251,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -293,14 +324,15 @@ def get_batch_size(workload_name):
def data_selection(
- workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]:
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py
index 92603f036..fdd4eb8b7 100644
--- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py
@@ -17,13 +17,14 @@
# Modified from https://github.com/davda54/sam.
class SAM(torch.optim.Optimizer):
-
- def __init__(self,
- params: spec.ParameterContainer,
- base_optimizer: torch.optim.Optimizer,
- rho: float = 0.05,
- adaptive: bool = False,
- **kwargs):
+ def __init__(
+ self,
+ params: spec.ParameterContainer,
+ base_optimizer: torch.optim.Optimizer,
+ rho: float = 0.05,
+ adaptive: bool = False,
+ **kwargs,
+ ):
if rho < 0.0:
raise ValueError(f'Invalid rho, should be non-negative: {rho}')
@@ -79,12 +80,18 @@ def _grad_norm(self):
# In case of model parallelism, put everything on the same device.
shared_device = self.param_groups[0]['params'][0].device
norm = torch.norm(
- torch.stack([((torch.abs(p) if group['adaptive'] else 1.0) *
- p.grad).norm(p=2).to(shared_device)
- for group in self.param_groups
- for p in group['params']
- if p.grad is not None]),
- p=2)
+ torch.stack(
+ [
+ ((torch.abs(p) if group['adaptive'] else 1.0) * p.grad)
+ .norm(p=2)
+ .to(shared_device)
+ for group in self.param_groups
+ for p in group['params']
+ if p.grad is not None
+ ]
+ ),
+ p=2,
+ )
return norm
def load_state_dict(self, state_dict: Dict):
@@ -92,11 +99,13 @@ def load_state_dict(self, state_dict: Dict):
self.base_optimizer.param_groups = self.param_groups
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates an AdamW optimizer and a learning rate schedule."""
del model_state
del rng
@@ -104,46 +113,50 @@ def init_optimizer_state(workload: spec.Workload,
# Create SAM optimizer with AdamW base.
base_optimizer = torch.optim.AdamW
optimizer_state = {
- 'optimizer':
- SAM(model_params.parameters(),
- base_optimizer=base_optimizer,
- rho=hyperparameters.rho,
- lr=hyperparameters.learning_rate,
- betas=(1.0 - hyperparameters.one_minus_beta1,
- hyperparameters.beta2),
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay),
+ 'optimizer': SAM(
+ model_params.parameters(),
+ base_optimizer=base_optimizer,
+ rho=hyperparameters.rho,
+ lr=hyperparameters.learning_rate,
+ betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2),
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ ),
}
def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup = LinearLR(
- optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
+ optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
- optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
+ optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
+ )
# Create learning rate schedule.
optimizer_state['scheduler'] = pytorch_cosine_warmup(
- workload.step_hint, hyperparameters, optimizer_state['optimizer'])
+ workload.step_hint, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -156,20 +169,24 @@ def update_params(
def _loss_fn(params, update_batch_norm=True):
"""Loss function used for training."""
logits_batch, new_model_state = workload.model_fn(
- params=params,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=update_batch_norm)
+ params=params,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=update_batch_norm,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -187,7 +204,8 @@ def _loss_fn(params, update_batch_norm=True):
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
optimizer_state['optimizer'].first_step(zero_grad=True)
@@ -198,7 +216,8 @@ def _loss_fn(params, update_batch_norm=True):
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].second_step(zero_grad=True)
optimizer_state['scheduler'].step()
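
For reference, a minimal usage sketch (not part of the patch) of the two-pass pattern the update above follows with the SAM class from this file: the first backward pass and first_step move the weights to the perturbed point, the second backward pass and second_step apply the actual base-optimizer (AdamW) update. The toy model, data, and hyperparameter values are placeholders.

    import torch

    model = torch.nn.Linear(4, 1)
    sam_optimizer = SAM(model.parameters(), base_optimizer=torch.optim.AdamW,
                        rho=0.05, lr=1e-3)
    inputs, targets = torch.randn(8, 4), torch.randn(8, 1)

    # First pass: gradients at the current weights, then move to w + rho * g/|g|.
    torch.nn.functional.mse_loss(model(inputs), targets).backward()
    sam_optimizer.first_step(zero_grad=True)
    # Second pass: gradients at the perturbed weights drive the actual AdamW update.
    torch.nn.functional.mse_loss(model(inputs), targets).backward()
    sam_optimizer.second_step(zero_grad=True)
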
@@ -206,29 +225,34 @@ def _loss_fn(params, update_batch_norm=True):
if global_step <= 100 or global_step % 500 == 0:
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': logging_loss.item(),
- 'grad_norm': grad_norm.item(),
- },
- global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- logging_loss.item(),
- grad_norm.item())
+ {
+ 'loss': logging_loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ logging_loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -265,14 +289,15 @@ def get_batch_size(workload_name):
def data_selection(
- workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]:
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py
index a5c2732ac..830dd4816 100644
--- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py
+++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py
@@ -61,13 +61,16 @@
@struct.dataclass
class QuantizedValue:
"""State associated with quantized value."""
+
quantized: chex.Array
diagonal: chex.Array # Diagonal (if extract_diagonal is set)
bucket_size: chex.Array
quantized_dtype: jnp.dtype = struct.field(
- pytree_node=False) # Dtype for the quantized value.
+ pytree_node=False
+ ) # Dtype for the quantized value.
extract_diagonal: bool = struct.field(
- pytree_node=False) # In case its centered.
+ pytree_node=False
+  )  # In case it's centered.
shape: Any = struct.field(pytree_node=False) # Shape of the tensor.
@classmethod
@@ -75,13 +78,16 @@ def from_float_value(cls, fvalue, quantized_dtype, extract_diagonal=False):
if isinstance(fvalue, list) and not fvalue:
return QuantizedValue([], [], [], quantized_dtype, extract_diagonal, [])
quantized, diagonal_fvalue, bucket_size = QuantizedValue.quantize(
- fvalue, quantized_dtype, extract_diagonal)
- return QuantizedValue(quantized,
- diagonal_fvalue,
- bucket_size,
- quantized_dtype,
- extract_diagonal,
- list(quantized.shape))
+ fvalue, quantized_dtype, extract_diagonal
+ )
+ return QuantizedValue(
+ quantized,
+ diagonal_fvalue,
+ bucket_size,
+ quantized_dtype,
+ extract_diagonal,
+ list(quantized.shape),
+ )
# Quantization is from Lingvo JAX optimizers.
# We extend it for int16 quantization of PSD matrices.
@@ -106,7 +112,8 @@ def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
if extract_diagonal and fvalue.ndim != 2:
raise ValueError(
- f'Input array {fvalue} must be 2D to work with extract_diagonal.')
+ f'Input array {fvalue} must be 2D to work with extract_diagonal.'
+ )
diagonal_fvalue = []
if extract_diagonal:
@@ -119,16 +126,17 @@ def quantize(cls, fvalue, quantized_dtype, extract_diagonal=False):
# We first decide the scale.
if fvalue.ndim < 1:
raise ValueError(
- f'Input array {fvalue} must have a strictly positive number of '
- 'dimensions.')
+ f'Input array {fvalue} must have a strictly positive number of '
+ 'dimensions.'
+ )
max_abs = jnp.max(jnp.abs(fvalue), axis=0)
bucket_size = max_abs / num_buckets
bs_expanded = bucket_size[jnp.newaxis, Ellipsis]
# To avoid divide by 0.0
- bs_nonzero = jnp.where(bs_expanded > 0.0,
- bs_expanded,
- jnp.ones_like(bs_expanded))
+ bs_nonzero = jnp.where(
+ bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded)
+ )
ratio = fvalue / bs_nonzero
# We use rounding to remove bias.
quantized = jnp.round(ratio)
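
For orientation, a small sketch (not part of the patch) of the per-column linear quantization scheme used by QuantizedValue.quantize above, assuming a signed integer dtype with num_buckets levels on each side of zero (e.g. 127 for int8); the function names are illustrative.

    import jax.numpy as jnp

    def quantize(x, num_buckets=127):
      max_abs = jnp.max(jnp.abs(x), axis=0)           # per-column scale
      bucket_size = max_abs / num_buckets
      safe = jnp.where(bucket_size > 0.0, bucket_size, jnp.ones_like(bucket_size))
      q = jnp.round(x / safe)                         # rounding, not truncation, to avoid bias
      return q.astype(jnp.int8), bucket_size

    def dequantize(q, bucket_size):
      return q.astype(jnp.float32) * bucket_size

    x = jnp.array([[0.5, -1.0], [0.25, 2.0]])
    q, scale = quantize(x)
    x_hat = dequantize(q, scale)                      # close to x up to rounding error
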
@@ -155,10 +163,11 @@ def to_float(self):
def _default_zero_field():
return struct.field(
- default_factory=functools.partial(jnp.array, 0, jnp.float32))
+ default_factory=functools.partial(jnp.array, 0, jnp.float32)
+ )
-T = TypeVar("T")
+T = TypeVar('T')
def _maybe_ix(ls, ix):
@@ -180,17 +189,19 @@ def wrap_f(x, *args, **kwargs):
InversePthRootDiagnosticsSubtype = TypeVar(
- "InversePthRootDiagnosticsSubtype", bound="InversePthRootDiagnostics")
+ 'InversePthRootDiagnosticsSubtype', bound='InversePthRootDiagnostics'
+)
@struct.dataclass
class InversePthRootDiagnostics:
"""Diagnostics for inverse p-th root iterative procedure.
- Given an inverse pth root B = A^(-1/p), contains the average and
- maximum diagonal and off diagonal absolute entrywise errors between
- (B^p A) and I.
- """
+ Given an inverse pth root B = A^(-1/p), contains the average and
+ maximum diagonal and off diagonal absolute entrywise errors between
+ (B^p A) and I.
+ """
+
max_diag_error: chex.Array = _default_zero_field()
avg_diag_error: chex.Array = _default_zero_field()
max_off_diag_error: chex.Array = _default_zero_field()
@@ -201,35 +212,41 @@ class InversePthRootDiagnostics:
def create(cls, pth_inverse_root, matrix, p):
"""Generates a diagnostics struct from (-1/p) root result."""
mat_m = jnp.matmul(
- mat_power(pth_inverse_root, p),
- matrix,
- precision=jax.lax.Precision.HIGHEST)
+ mat_power(pth_inverse_root, p),
+ matrix,
+ precision=jax.lax.Precision.HIGHEST,
+ )
num_off_diag_entries = mat_m.size - jnp.diag(mat_m).size
diag_error = jnp.abs(jnp.diag(mat_m) - 1).astype(jnp.float32)
off_diag_error = jnp.abs(mat_m - jnp.diag(jnp.diag(mat_m))).astype(
- jnp.float32)
+ jnp.float32
+ )
return cls(
- max_diag_error=jnp.max(diag_error).astype(jnp.float32),
- avg_diag_error=jnp.mean(diag_error).astype(jnp.float32),
- max_off_diag_error=jnp.max(off_diag_error).astype(jnp.float32),
- avg_off_diag_error=(jnp.sum(off_diag_error) /
- num_off_diag_entries).astype(jnp.float32),
- p=jnp.array(p, jnp.float32))
+ max_diag_error=jnp.max(diag_error).astype(jnp.float32),
+ avg_diag_error=jnp.mean(diag_error).astype(jnp.float32),
+ max_off_diag_error=jnp.max(off_diag_error).astype(jnp.float32),
+ avg_off_diag_error=(
+ jnp.sum(off_diag_error) / num_off_diag_entries
+ ).astype(jnp.float32),
+ p=jnp.array(p, jnp.float32),
+ )
LOBPCGDiagnosticsSubtype = TypeVar(
- "LOBPCGDiagnosticsSubtype", bound="LOBPCGDiagnostics")
+ 'LOBPCGDiagnosticsSubtype', bound='LOBPCGDiagnostics'
+)
@struct.dataclass
class LOBPCGDiagnostics:
"""Diagnostics for iterative LOBPCG eigenvalue routine.
- Contains consistency error for LOBPCG eigenvalue routine, which
- refers to |A v - lambda v| / (lambda + |A v|) for a proposed eigenpair
- (v, lambda). This metics dataclass retains consistency error
- and other useful LOBPCG values.
- """
+ Contains consistency error for LOBPCG eigenvalue routine, which
+ refers to |A v - lambda v| / (lambda + |A v|) for a proposed eigenpair
+  (v, lambda). This metrics dataclass retains consistency error
+ and other useful LOBPCG values.
+ """
+
lobpcg_iters: chex.Array = _default_zero_field()
max_consistency_error: chex.Array = _default_zero_field()
avg_consistency_error: chex.Array = _default_zero_field()
@@ -248,7 +265,8 @@ def create(cls, matrix, eigvals, eigvecs, lobpcg_iters):
mat_eigvecs = matrix.dot(eigvecs, precision=precision)
consistency_error_unnormalized = jnp.linalg.norm(
- mat_eigvecs - eigvals * eigvecs, ord=2, axis=0)
+ mat_eigvecs - eigvals * eigvecs, ord=2, axis=0
+ )
normalization = jnp.linalg.norm(mat_eigvecs, ord=2, axis=0) + eigvals
consistency_error = consistency_error_unnormalized / normalization
@@ -256,20 +274,22 @@ def create(cls, matrix, eigvals, eigvecs, lobpcg_iters):
orthogonality_error -= jnp.diag(jnp.diag(orthogonality_error))
return cls(
- lobpcg_iters=jnp.array(lobpcg_iters, jnp.float32),
- max_consistency_error=jnp.max(consistency_error).astype(jnp.float32),
- avg_consistency_error=jnp.mean(consistency_error).astype(jnp.float32),
- avg_orthogonality_error=(jnp.sum(orthogonality_error) /
- num_off_diag).astype(jnp.float32),
- max_eigenvalue=jnp.max(eigvals).astype(jnp.float32),
- min_eigenvalue=jnp.min(eigvals).astype(jnp.float32),
- num_topk_eigenvectors=jnp.array(num_topk, jnp.float32),
+ lobpcg_iters=jnp.array(lobpcg_iters, jnp.float32),
+ max_consistency_error=jnp.max(consistency_error).astype(jnp.float32),
+ avg_consistency_error=jnp.mean(consistency_error).astype(jnp.float32),
+ avg_orthogonality_error=(
+ jnp.sum(orthogonality_error) / num_off_diag
+ ).astype(jnp.float32),
+ max_eigenvalue=jnp.max(eigvals).astype(jnp.float32),
+ min_eigenvalue=jnp.min(eigvals).astype(jnp.float32),
+ num_topk_eigenvectors=jnp.array(num_topk, jnp.float32),
)
@struct.dataclass
class TrainingMetrics:
"""Diagnostic metrics from training."""
+
# Error for inverse-pth roots.
inverse_pth_root_errors: chex.Array = _default_zero_field()
# Iteration count for inverse-pth roots.
@@ -283,20 +303,24 @@ class TrainingMetrics:
total_retries: chex.Array = _default_zero_field()
lobpcg_diagnostics: LOBPCGDiagnostics = struct.field(
- default_factory=LOBPCGDiagnostics)
+ default_factory=LOBPCGDiagnostics
+ )
# Rich matrix entrywise error diagnostics, if enabled.
inverse_pth_root_diagnostics: InversePthRootDiagnostics = struct.field(
- default_factory=InversePthRootDiagnostics)
+ default_factory=InversePthRootDiagnostics
+ )
# Diagnostics applied to the conditioned p-th root problem, after top
# eigenvectors are removed, if LOBPCG is being applied.
conditioned_inverse_pth_root_diagnostics: InversePthRootDiagnostics = (
- struct.field(default_factory=InversePthRootDiagnostics))
+ struct.field(default_factory=InversePthRootDiagnostics)
+ )
# TODO(rohananil): Add more important metrics to track during training.
# Per parameter optimizer state used in data-parallel training.
class ParameterStats(NamedTuple):
"""State associated to each parameter of the model being trained."""
+
diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
statistics: Optional[List[Any]] # Statistics (QuantizedValue, chex.Array)
preconditioners: List[Any] # Preconditioners (QuantizedValue, chex.Array)
@@ -321,12 +345,14 @@ class GlobalShardedParameterStats:
@struct.dataclass
class LocalShardedParameterStats:
"""State associated to each parameter of the model being trained."""
+
diagonal_statistics: QuantizedValue # Accumulator for diagonal preconditioner
diagonal_momentum: QuantizedValue # Momentum for the diagonal preconditioner
momentum: QuantizedValue # Momentum for the shampoo preconditioner
training_metrics: Union[TrainingMetrics, optax.MaskedNode]
index_start: Union[np.int32, int] = struct.field(
- pytree_node=False) # Index into global statistics array
+ pytree_node=False
+ ) # Index into global statistics array
sizes: Any = struct.field(pytree_node=False) # Sizes of the statistics.
@@ -336,39 +362,44 @@ def default_training_metrics():
def init_training_metrics(
- num_statistics,
- generate_training_metrics,
+ num_statistics,
+ generate_training_metrics,
):
"""Initialize TrainingMetrics, masked if disabled."""
if not generate_training_metrics:
return optax.MaskedNode()
return jax.tree.map(
- functools.partial(jnp.repeat, repeats=num_statistics),
- default_training_metrics())
+ functools.partial(jnp.repeat, repeats=num_statistics),
+ default_training_metrics(),
+ )
def init_training_metrics_shapes(
- num_statistics,
- generate_training_metrics,
+ num_statistics,
+ generate_training_metrics,
):
"""Initialize training metrics shape/dtype."""
seed = init_training_metrics(
- num_statistics,
- generate_training_metrics,
+ num_statistics,
+ generate_training_metrics,
)
return jax.tree.map(lambda arr: [list(arr.shape), arr.dtype], seed)
-def init_training_metrics_pspec(generate_training_metrics,):
+def init_training_metrics_pspec(
+ generate_training_metrics,
+):
"""Initialize training metrics partition specification."""
if not generate_training_metrics:
return optax.MaskedNode()
- return jax.tree.map(lambda _: jax.sharding.PartitionSpec(),
- default_training_metrics())
+ return jax.tree.map(
+ lambda _: jax.sharding.PartitionSpec(), default_training_metrics()
+ )
class ShardedShampooStats(NamedTuple):
"""Shampoo state in sharded mode."""
+
global_stats: Any
local_stats: Any
@@ -406,35 +437,35 @@ class PreconditionerType(enum.IntEnum):
def power_iteration(
- matrix,
- num_iters=100,
- error_tolerance=1e-6,
- precision=lax.Precision.HIGHEST,
- padding_start=None,
+ matrix,
+ num_iters=100,
+ error_tolerance=1e-6,
+ precision=lax.Precision.HIGHEST,
+ padding_start=None,
):
r"""Power iteration algorithm.
- The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
- a scalar `\lambda` , which is the greatest (in absolute value) eigenvalue
- of `A`, and a vector v, which is the corresponding eigenvector of `A`.
-
- References:
- [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)
-
- Args:
- matrix: the symmetric PSD matrix.
- num_iters: Number of iterations.
- error_tolerance: Iterative exit condition.
- precision: precision XLA related flag, the available options are: a)
- lax.Precision.DEFAULT (better step time, but not precise) b)
- lax.Precision.HIGH (increased precision, slower) c)
- lax.Precision.HIGHEST (best possible precision, slowest)
- padding_start: if set, assumes rows and columns after padding_start are
- zero.
-
- Returns:
- eigen vector, eigen value
- """
+ The power iteration algorithm takes a symmetric PSD matrix `A`, and produces
+  a scalar `\lambda`, which is the greatest (in absolute value) eigenvalue
+ of `A`, and a vector v, which is the corresponding eigenvector of `A`.
+
+ References:
+ [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)
+
+ Args:
+ matrix: the symmetric PSD matrix.
+ num_iters: Number of iterations.
+ error_tolerance: Iterative exit condition.
+ precision: precision XLA related flag, the available options are: a)
+ lax.Precision.DEFAULT (better step time, but not precise) b)
+ lax.Precision.HIGH (increased precision, slower) c)
+ lax.Precision.HIGHEST (best possible precision, slowest)
+ padding_start: if set, assumes rows and columns after padding_start are
+ zero.
+
+ Returns:
+ eigen vector, eigen value
+ """
matrix_size = matrix.shape[-1]
def _iter_condition(state):
@@ -446,32 +477,38 @@ def _iter_body(state):
i, new_v, s, s_v, unused_run_step = state
new_v = new_v / jnp.linalg.norm(new_v)
- s_v = jnp.einsum("ij,j->i", matrix, new_v, precision=precision)
- s_new = jnp.einsum("i,i->", new_v, s_v, precision=precision)
- return (i + 1,
- s_v,
- s_new,
- s_v,
- jnp.greater(jnp.abs(s_new - s), error_tolerance))
+ s_v = jnp.einsum('ij,j->i', matrix, new_v, precision=precision)
+ s_new = jnp.einsum('i,i->', new_v, s_v, precision=precision)
+ return (
+ i + 1,
+ s_v,
+ s_new,
+ s_v,
+ jnp.greater(jnp.abs(s_new - s), error_tolerance),
+ )
# Figure out how to use step as seed for random.
- v_0 = np.random.RandomState(1729).uniform(-1.0, 1.0,
- matrix_size).astype(matrix.dtype)
+ v_0 = (
+ np.random.RandomState(1729)
+ .uniform(-1.0, 1.0, matrix_size)
+ .astype(matrix.dtype)
+ )
v_0 = jnp.array(v_0)
if padding_start is not None:
- v_0 *= (jnp.arange(len(v_0), dtype=jnp.int32) < padding_start)
+ v_0 *= jnp.arange(len(v_0), dtype=jnp.int32) < padding_start
init_state = tuple([0, v_0, jnp.zeros([], dtype=matrix.dtype), v_0, True])
- _, v_out, s_out, _, _ = lax.while_loop(_iter_condition, _iter_body,
- init_state)
+ _, v_out, s_out, _, _ = lax.while_loop(
+ _iter_condition, _iter_body, init_state
+ )
v_out = v_out / jnp.linalg.norm(v_out)
return v_out, s_out
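A plain-Python sketch of the same algorithm, without the lax.while_loop, padding handling, or the error-tolerance early exit, to show what the loop converges to:

import jax.numpy as jnp

def power_iteration_sketch(matrix, num_iters=100):
  v = jnp.ones((matrix.shape[-1],), matrix.dtype)
  s = jnp.zeros((), matrix.dtype)
  for _ in range(num_iters):
    v = v / jnp.linalg.norm(v)
    mv = matrix @ v
    s = v @ mv                     # Rayleigh quotient at the current iterate
    v = mv
  return v / jnp.linalg.norm(v), s

# power_iteration_sketch(jnp.array([[2., 1.], [1., 3.]]))[1] is ~3.618,
# the largest eigenvalue of this test matrix.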
def mat_power(
- mat_m,
- p,
- precision=lax.Precision.HIGHEST,
+ mat_m,
+ p,
+ precision=lax.Precision.HIGHEST,
):
"""A simple matrix power method. M^p where p can be TracedValue."""
power = jnp.eye(mat_m.shape[0], dtype=_MAT_INV_PTH_ROOT_DTYPE)
@@ -483,9 +520,11 @@ def _iter_condition(state):
def _iter_body(state):
i, power, mat = state
- power = jax.lax.cond(i % 2 == 1,
- lambda: jnp.matmul(mat, power, precision=precision),
- lambda: power)
+ power = jax.lax.cond(
+ i % 2 == 1,
+ lambda: jnp.matmul(mat, power, precision=precision),
+ lambda: power,
+ )
i //= 2
mat = jnp.matmul(mat, mat, precision=precision)
return i, power, mat
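A standalone sketch of the square-and-multiply scheme this loop body implements, with p treated as a plain Python int rather than a traced value:

import jax.numpy as jnp

def mat_power_sketch(mat_m, p):
  power = jnp.eye(mat_m.shape[0], dtype=mat_m.dtype)
  while p > 0:
    if p % 2 == 1:                 # odd exponent: fold the current base in
      power = mat_m @ power
    p //= 2
    mat_m = mat_m @ mat_m          # square the base for the next bit of p
  return power

# mat_power_sketch(a, 4) matches a @ a @ a @ a up to float round-off.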
@@ -508,78 +547,81 @@ def _stable_subtract(b, a_minus_b):
return (b**exp) * jnp.expm1(exp * jnp.log1p(a_minus_b / b))
return jnp.where(
- # Choose the branch with the best log1p approximation.
- jnp.abs(a_minus_b / b) < jnp.abs(a_minus_b / a),
- -_stable_subtract(a, -a_minus_b),
- _stable_subtract(b, a_minus_b))
+ # Choose the branch with the best log1p approximation.
+ jnp.abs(a_minus_b / b) < jnp.abs(a_minus_b / a),
+ -_stable_subtract(a, -a_minus_b),
+ _stable_subtract(b, a_minus_b),
+ )
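The expm1/log1p form above evaluates b**exp * ((1 + (a-b)/b)**exp - 1) = a**exp - b**exp without catastrophic cancellation when a and b are close; a quick standalone check of that identity (exp = -1/p matches how the enclosing helper is used, but any exponent works):

import jax.numpy as jnp

def stable_pow_diff(b, a_minus_b, exp):
  # a**exp - b**exp, rewritten as b**exp * ((a / b)**exp - 1).
  return (b**exp) * jnp.expm1(exp * jnp.log1p(a_minus_b / b))

b = jnp.asarray(1.0, jnp.float32)
d = jnp.asarray(1e-6, jnp.float32)       # a = b + d, nearly equal operands
exp = -0.5                               # e.g. p = 2
print(stable_pow_diff(b, d, exp))        # ~ -5e-07
print((b + d) ** exp - b**exp)           # naive difference loses digits in f32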
def matrix_inverse_pth_root(
- matrix,
- p,
- num_iters=100,
- ridge_epsilon=1e-6,
- error_tolerance=1e-6,
- precision=lax.Precision.HIGHEST,
- relative_matrix_epsilon=True,
- lobpcg_topk_precondition=0,
- lobpcg_max_iter=0,
- padding_start=None,
- prev=None,
- eigh=False,
+ matrix,
+ p,
+ num_iters=100,
+ ridge_epsilon=1e-6,
+ error_tolerance=1e-6,
+ precision=lax.Precision.HIGHEST,
+ relative_matrix_epsilon=True,
+ lobpcg_topk_precondition=0,
+ lobpcg_max_iter=0,
+ padding_start=None,
+ prev=None,
+ eigh=False,
):
"""Computes `matrix^(-1/p)`, where `p` is a positive integer.
- This function uses the Eigh or Coupled newton iterations algorithm for
- the computation of a matrix's inverse pth root.
-
-
- References:
- [Functions of Matrices, Theory and Computation,
- Nicholas J Higham, Pg 184, Eq 7.18](
- https://epubs.siam.org/doi/book/10.1137/1.9780898717778)
-
- Args:
- matrix: the symmetric PSD matrix whose power it to be computed
- p: exponent, for p a positive integer.
- num_iters: Maximum number of iterations.
- ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
- error_tolerance: Error indicator, useful for early termination.
- precision: precision XLA related flag, the available options are: a)
- lax.Precision.DEFAULT (better step time, but not precise) b)
- lax.Precision.HIGH (increased precision, slower) c)
- lax.Precision.HIGHEST (best possible precision, slowest)
- relative_matrix_epsilon: Whether to use relative epsilon to the max eigen
- value when computing inverse-pth root.
- lobpcg_topk_precondition: If nonzero, specifies the number of top
- eigenvectors to subtract out before performing LOBPCG. Note this makes
- relative_matrix_epsilon essentially free.
- lobpcg_max_iter: Maximum iteration count for LOBPCG, defaults to
- `lobpcg_topk_precondition`.
- padding_start: If the input matrix was padded, then zeros out columns and
- rows at the padding start.
- prev: previous iteration's solution, zero-padded (unused)
- eigh: If True, uses eigh for inverse-pth root computation.
-
- Returns:
- `(matrix + eps)^(-1/p)` and error metrics.
-
- Note `eps` is not added to zeroed out padding rows and
- columns. `eps` is just `ridge_epsilon` if
- `relative_matrix_epsilon` is set to `False`, otherwise, it is the
- ridge epsilon value scaled by the derived maximum eigenvalue of
- the input matrix.
- """
+ This function uses the Eigh or Coupled newton iterations algorithm for
+ the computation of a matrix's inverse pth root.
+
+
+ References:
+ [Functions of Matrices, Theory and Computation,
+ Nicholas J Higham, Pg 184, Eq 7.18](
+ https://epubs.siam.org/doi/book/10.1137/1.9780898717778)
+
+ Args:
+    matrix: the symmetric PSD matrix whose power is to be computed
+ p: exponent, for p a positive integer.
+ num_iters: Maximum number of iterations.
+ ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
+ error_tolerance: Error indicator, useful for early termination.
+ precision: precision XLA related flag, the available options are: a)
+ lax.Precision.DEFAULT (better step time, but not precise) b)
+ lax.Precision.HIGH (increased precision, slower) c)
+ lax.Precision.HIGHEST (best possible precision, slowest)
+ relative_matrix_epsilon: Whether to use relative epsilon to the max eigen
+ value when computing inverse-pth root.
+ lobpcg_topk_precondition: If nonzero, specifies the number of top
+ eigenvectors to subtract out before performing LOBPCG. Note this makes
+ relative_matrix_epsilon essentially free.
+ lobpcg_max_iter: Maximum iteration count for LOBPCG, defaults to
+ `lobpcg_topk_precondition`.
+ padding_start: If the input matrix was padded, then zeros out columns and
+ rows at the padding start.
+ prev: previous iteration's solution, zero-padded (unused)
+ eigh: If True, uses eigh for inverse-pth root computation.
+
+ Returns:
+ `(matrix + eps)^(-1/p)` and error metrics.
+
+ Note `eps` is not added to zeroed out padding rows and
+ columns. `eps` is just `ridge_epsilon` if
+ `relative_matrix_epsilon` is set to `False`, otherwise, it is the
+ ridge epsilon value scaled by the derived maximum eigenvalue of
+ the input matrix.
+ """
if eigh:
- return matrix_inverse_pth_root_eigh(matrix,
- p,
- ridge_epsilon,
- error_tolerance,
- precision,
- relative_matrix_epsilon,
- padding_start,
- prev)
+ return matrix_inverse_pth_root_eigh(
+ matrix,
+ p,
+ ridge_epsilon,
+ error_tolerance,
+ precision,
+ relative_matrix_epsilon,
+ padding_start,
+ prev,
+ )
del prev
assert matrix.shape[0] == matrix.shape[1]
@@ -596,7 +638,8 @@ def matrix_inverse_pth_root(
if padding_start is not None:
# Zero out padding in identity as well for convergence checks.
ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
- matrix.dtype)
+ matrix.dtype
+ )
matrix *= ix[jnp.newaxis, :]
matrix *= ix[:, jnp.newaxis]
identity *= ix
@@ -607,18 +650,23 @@ def matrix_inverse_pth_root(
eigvals, eigvecs, lobpcg_diagnostics = None, None, None
if lobpcg_topk_precondition > 0:
# TODO(vladf): reuse previous top-k as the initial search directions
- pad_shape = (matrix_size - lobpcg_topk_precondition,
- lobpcg_topk_precondition)
+ pad_shape = (
+ matrix_size - lobpcg_topk_precondition,
+ lobpcg_topk_precondition,
+ )
search_dirs = jnp.concatenate(
- (jnp.eye(lobpcg_topk_precondition), jnp.zeros(pad_shape)), axis=0)
+ (jnp.eye(lobpcg_topk_precondition), jnp.zeros(pad_shape)), axis=0
+ )
eigvals, eigvecs, lobpcg_iters = linalg.lobpcg_standard( # pylint: disable=unbalanced-tuple-unpacking
- matrix, search_dirs,
- lobpcg_topk_precondition if lobpcg_max_iter == 0 else lobpcg_max_iter)
+ matrix,
+ search_dirs,
+ lobpcg_topk_precondition if lobpcg_max_iter == 0 else lobpcg_max_iter,
+ )
lobpcg_diagnostics = LOBPCGDiagnostics.create(
- matrix,
- eigvals,
- eigvecs,
- lobpcg_iters,
+ matrix,
+ eigvals,
+ eigvecs,
+ lobpcg_iters,
)
# The minimal eigenvalue among top-k becomes the maximal one in the whole
@@ -628,7 +676,8 @@ def matrix_inverse_pth_root(
# Deflate out top eigenvectors to reduce matrix condition number.
matrix -= scaled_vecs.dot(
- scaled_vecs.T, precision=jax.lax.Precision.HIGHEST)
+ scaled_vecs.T, precision=jax.lax.Precision.HIGHEST
+ )
if relative_matrix_epsilon:
if eigvals is not None:
@@ -637,11 +686,12 @@ def matrix_inverse_pth_root(
# Only use power iteration if lobpcg wasn't already used to derive the
# top eigenvalue.
_, max_ev = power_iteration(
- matrix=matrix,
- num_iters=100,
- error_tolerance=1e-6,
- precision=precision,
- padding_start=padding_start)
+ matrix=matrix,
+ num_iters=100,
+ error_tolerance=1e-6,
+ precision=precision,
+ padding_start=padding_start,
+ )
else:
# Use absolute matrix epsilon scaling otherwise.
max_ev = 1.0
@@ -654,8 +704,9 @@ def matrix_inverse_pth_root(
def _iter_condition(state):
i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, error_ratio = state
- error_above_threshold = jnp.logical_and(error > error_tolerance,
- error_ratio < max_error_ratio)
+ error_above_threshold = jnp.logical_and(
+ error > error_tolerance, error_ratio < max_error_ratio
+ )
return jnp.logical_and(i < num_iters, error_above_threshold)
def _iter_body(state):
@@ -673,7 +724,6 @@ def _iter_body(state):
iters = 0
error_ratio = 0.0
else:
-
retry_loop_error_threshold = 0.05
num_tries = 6
init_outer_state = tuple([0, identity, 1000.0, 100, 1.0, True])
@@ -691,23 +741,26 @@ def _outer_body_fn(state):
new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
init_state = tuple(
- [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, 1.0])
+ [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, 1.0]
+ )
iters, mat_m, mat_h, old_mat_h, error, error_ratio = lax.while_loop(
- _iter_condition, _iter_body, init_state)
+ _iter_condition, _iter_body, init_state
+ )
error = jnp.max(jnp.abs(mat_m - identity)).astype(jnp.float32)
is_converged = jnp.asarray(error_ratio < max_error_ratio, old_mat_h.dtype)
- resultant_mat_h = is_converged * \
- mat_h + (1 - is_converged) * old_mat_h
- return (i + 1,
- resultant_mat_h,
- error,
- iters,
- error_ratio,
- error > retry_loop_error_threshold)
-
- loop_outputs = jax.lax.while_loop(_outer_iter_condition_fn,
- _outer_body_fn,
- init_outer_state)
+ resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h
+ return (
+ i + 1,
+ resultant_mat_h,
+ error,
+ iters,
+ error_ratio,
+ error > retry_loop_error_threshold,
+ )
+
+ loop_outputs = jax.lax.while_loop(
+ _outer_iter_condition_fn, _outer_body_fn, init_outer_state
+ )
total_retries, resultant_mat_h, error, iters, error_ratio, _ = loop_outputs
conditioned_resultant_mat = resultant_mat_h
@@ -723,35 +776,39 @@ def _outer_body_fn(state):
pth_diff = _pth_root_difference(ridge_epsilon, jnp.min(eigvals), eigvals, p)
scaled_vecs = eigvecs * jnp.sqrt(pth_diff)
resultant_mat_h = conditioned_resultant_mat - scaled_vecs.dot(
- scaled_vecs.T, precision=jax.lax.Precision.HIGHEST)
+ scaled_vecs.T, precision=jax.lax.Precision.HIGHEST
+ )
error_metrics = TrainingMetrics(
- inverse_pth_root_errors=jnp.array(error, jnp.float32),
- inverse_pth_root_iters=jnp.array(iters, jnp.float32),
- final_error_ratio=jnp.array(error_ratio, jnp.float32),
- max_eigen_value=jnp.array(max_ev, jnp.float32),
- total_retries=jnp.array(total_retries, jnp.float32))
+ inverse_pth_root_errors=jnp.array(error, jnp.float32),
+ inverse_pth_root_iters=jnp.array(iters, jnp.float32),
+ final_error_ratio=jnp.array(error_ratio, jnp.float32),
+ max_eigen_value=jnp.array(max_ev, jnp.float32),
+ total_retries=jnp.array(total_retries, jnp.float32),
+ )
if lobpcg_topk_precondition > 0:
- damped_matrix = matrix + \
- (ridge_epsilon * (10**total_retries) * identity)
+ damped_matrix = matrix + (ridge_epsilon * (10**total_retries) * identity)
conditioned_diagnostics = InversePthRootDiagnostics.create(
- conditioned_resultant_mat, damped_matrix, p)
+ conditioned_resultant_mat, damped_matrix, p
+ )
unconditioned_damped_matrix = original_matrix + ridge_epsilon * identity
unconditioned_diagnostics = InversePthRootDiagnostics.create(
- resultant_mat_h, unconditioned_damped_matrix, p)
+ resultant_mat_h, unconditioned_damped_matrix, p
+ )
# The max entrywise error in error_metrics.inverse_pth_root_errors refers
# to what was derived from the inverse pth root iteration, which with
# LOBPCG refers to the conditioned problem. Make sure to use the error
# from the unconditioned problem.
unconditional_errors = jnp.maximum(
- unconditioned_diagnostics.max_diag_error,
- unconditioned_diagnostics.max_off_diag_error)
+ unconditioned_diagnostics.max_diag_error,
+ unconditioned_diagnostics.max_off_diag_error,
+ )
error_metrics = error_metrics.replace(
- inverse_pth_root_errors=unconditional_errors,
- lobpcg_diagnostics=lobpcg_diagnostics,
- conditioned_inverse_pth_root_diagnostics=conditioned_diagnostics,
- inverse_pth_root_diagnostics=unconditioned_diagnostics,
+ inverse_pth_root_errors=unconditional_errors,
+ lobpcg_diagnostics=lobpcg_diagnostics,
+ conditioned_inverse_pth_root_diagnostics=conditioned_diagnostics,
+ inverse_pth_root_diagnostics=unconditioned_diagnostics,
)
if padding_start is not None:
@@ -759,9 +816,9 @@ def _outer_body_fn(state):
# due to some TPU hosts not having the same number of preconditioning
# matrices.
resultant_mat_h = jnp.where(padding_start == 0, 0.0, resultant_mat_h)
- error = jnp.where(padding_start == 0,
- 0.0,
- error_metrics.inverse_pth_root_errors)
+ error = jnp.where(
+ padding_start == 0, 0.0, error_metrics.inverse_pth_root_errors
+ )
error_metrics = error_metrics.replace(inverse_pth_root_errors=error)
resultant_mat_h = jnp.asarray(resultant_mat_h, orig_dtype)
@@ -769,44 +826,44 @@ def _outer_body_fn(state):
def matrix_inverse_pth_root_eigh(
- matrix,
- p,
- ridge_epsilon=1e-6,
- error_tolerance=1e-6,
- precision=lax.Precision.HIGHEST,
- relative_matrix_epsilon=True,
- padding_start=None,
- prev=None,
+ matrix,
+ p,
+ ridge_epsilon=1e-6,
+ error_tolerance=1e-6,
+ precision=lax.Precision.HIGHEST,
+ relative_matrix_epsilon=True,
+ padding_start=None,
+ prev=None,
):
"""Computes `matrix^(-1/p)`, where `p` is a positive integer.
- This function uses eigh for the computation of a matrix's inverse pth
- root.
-
- Args:
- matrix: the symmetric PSD matrix whose power it to be computed
- p: exponent, for p a positive integer.
- ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
- error_tolerance: Error indicator, useful for early termination.
- precision: precision XLA related flag, the available options are: a)
- lax.Precision.DEFAULT (better step time, but not precise) b)
- lax.Precision.HIGH (increased precision, slower) c)
- lax.Precision.HIGHEST (best possible precision, slowest)
- relative_matrix_epsilon: Whether to use relative epsilon to the max eigen
- value when computing inverse-pth root.
- padding_start: If the input matrix was padded, then zeros out columns and
- rows at the padding start.
- prev: previous iteration's solution, zero-padded (unused)
-
- Returns:
- `(matrix + eps)^(-1/p)` and error metrics.
-
- Note `eps` is not added to zeroed out padding rows and
- columns. `eps` is just `ridge_epsilon` if
- `relative_matrix_epsilon` is set to `False`, otherwise, it is the
- ridge epsilon value scaled by the derived maximum eigenvalue of
- the input matrix.
- """
+ This function uses eigh for the computation of a matrix's inverse pth
+ root.
+
+ Args:
+    matrix: the symmetric PSD matrix whose power is to be computed
+ p: exponent, for p a positive integer.
+ ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
+ error_tolerance: Error indicator, useful for early termination.
+ precision: precision XLA related flag, the available options are: a)
+ lax.Precision.DEFAULT (better step time, but not precise) b)
+ lax.Precision.HIGH (increased precision, slower) c)
+ lax.Precision.HIGHEST (best possible precision, slowest)
+ relative_matrix_epsilon: Whether to use relative epsilon to the max eigen
+ value when computing inverse-pth root.
+ padding_start: If the input matrix was padded, then zeros out columns and
+ rows at the padding start.
+ prev: previous iteration's solution, zero-padded (unused)
+
+ Returns:
+ `(matrix + eps)^(-1/p)` and error metrics.
+
+ Note `eps` is not added to zeroed out padding rows and
+ columns. `eps` is just `ridge_epsilon` if
+ `relative_matrix_epsilon` is set to `False`, otherwise, it is the
+ ridge epsilon value scaled by the derived maximum eigenvalue of
+ the input matrix.
+ """
del prev
assert matrix.shape[0] == matrix.shape[1]
matrix_size = matrix.shape[0]
@@ -816,17 +873,19 @@ def matrix_inverse_pth_root_eigh(
identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE)
if padding_start is not None:
ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
- matrix.dtype)
+ matrix.dtype
+ )
matrix *= ix[jnp.newaxis, :]
matrix *= ix[:, jnp.newaxis]
identity *= ix
if relative_matrix_epsilon:
_, max_ev = power_iteration(
- matrix=matrix,
- num_iters=100,
- error_tolerance=error_tolerance,
- precision=precision,
- padding_start=padding_start)
+ matrix=matrix,
+ num_iters=100,
+ error_tolerance=error_tolerance,
+ precision=precision,
+ padding_start=padding_start,
+ )
else:
# Use absolute matrix epsilon scaling otherwise.
max_ev = 1.0
@@ -837,9 +896,9 @@ def matrix_inverse_pth_root_eigh(
if padding_start is not None:
e *= jnp.flip(ix)
mm = functools.partial(jnp.matmul, precision=precision)
- inv_e = jnp.where(e == 0.0,
- 0.0,
- jnp.power(jnp.maximum(e, ridge_epsilon), alpha))
+ inv_e = jnp.where(
+ e == 0.0, 0.0, jnp.power(jnp.maximum(e, ridge_epsilon), alpha)
+ )
val = mm(mm(u, jnp.diag(inv_e)), u.T)
root = u * jnp.sqrt(inv_e)
val = mm(root, root.T)
@@ -849,12 +908,13 @@ def matrix_inverse_pth_root_eigh(
eig_error *= jnp.flip(ix)
error = jnp.max(jnp.abs(eig_error))
error_metrics = TrainingMetrics(
- inverse_pth_root_errors=jnp.array(error, jnp.float32))
+ inverse_pth_root_errors=jnp.array(error, jnp.float32)
+ )
if padding_start is not None:
val = jnp.where(padding_start == 0, 0.0, val)
- error = jnp.where(padding_start == 0,
- 0.0,
- error_metrics.inverse_pth_root_errors)
+ error = jnp.where(
+ padding_start == 0, 0.0, error_metrics.inverse_pth_root_errors
+ )
error_metrics = error_metrics.replace(inverse_pth_root_errors=error)
val = jnp.asarray(val, orig_dtype)
return val, error_metrics
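A condensed standalone sketch of the eigh path: decompose, apply x -> x**(-1/p) to the ridge-clipped eigenvalues, and rebuild (padding handling, relative epsilon scaling, and error metrics omitted):

import jax.numpy as jnp

def inverse_pth_root_eigh_sketch(matrix, p, ridge_epsilon=1e-6):
  e, u = jnp.linalg.eigh(matrix)                      # symmetric eigendecomp
  inv_e = jnp.power(jnp.maximum(e, ridge_epsilon), -1.0 / p)
  root = u * jnp.sqrt(inv_e)                          # u @ diag(sqrt(inv_e))
  return root @ root.T                                # u @ diag(inv_e) @ u.T

# For p = 2 this approximates matrix**(-1/2) of a symmetric PSD input.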
@@ -863,17 +923,17 @@ def matrix_inverse_pth_root_eigh(
def merge_small_dims(shape_to_merge, max_dim):
"""Merge small dimensions.
- If there are some small dimensions, we collapse them:
- e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
- [1, 2, 768, 1, 2048] --> [2, 768, 2048]
+ If there are some small dimensions, we collapse them:
+ e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
+ [1, 2, 768, 1, 2048] --> [2, 768, 2048]
- Args:
- shape_to_merge: Shape to merge small dimensions.
- max_dim: Maximal dimension of output shape used in merging.
+ Args:
+ shape_to_merge: Shape to merge small dimensions.
+ max_dim: Maximal dimension of output shape used in merging.
- Returns:
- Merged shape.
- """
+ Returns:
+ Merged shape.
+ """
if shape_to_merge and np.all(np.array(shape_to_merge) == 1):
return [1]
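A standalone sketch of the greedy merging rule the docstring describes: accumulate a running product until adding the next dim would exceed max_dim, then start a new output dim. This mirrors the documented examples and is an illustrative re-expression, not the function body itself.

import numpy as np

def merge_small_dims_sketch(shape_to_merge, max_dim):
  if shape_to_merge and np.all(np.array(shape_to_merge) == 1):
    return [1]
  merged, product = [], 1
  for d in shape_to_merge:
    if product * d <= max_dim:
      product *= d                # keep folding small dims together
    else:
      merged.append(product)      # flush the accumulated dim, start a new one
      product = d
  merged.append(product)
  return merged

# merge_small_dims_sketch([1, 2, 512, 1, 2048, 1, 3, 4], 1024) -> [1024, 2048, 12]
# merge_small_dims_sketch([1, 2, 768, 1, 2048], 1024)          -> [2, 768, 2048]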
@@ -894,20 +954,23 @@ def merge_small_dims(shape_to_merge, max_dim):
def pad_square_matrix(mat, max_size):
"""Pad a square matrix up to max_size.
- Args:
- mat: a matrix to pad.
- max_size: matrix size requested.
+ Args:
+ mat: a matrix to pad.
+ max_size: matrix size requested.
- Returns:
- Given M returns [[M, 0], [0, I]]
- """
+ Returns:
+ Given M returns [[M, 0], [0, I]]
+ """
rows, cols = mat.shape
if rows != cols:
- raise ValueError("Must have rows == cols, instead got "
- f"rows={rows}, cols={cols}")
+ raise ValueError(
+ f'Must have rows == cols, instead got rows={rows}, cols={cols}'
+ )
if cols > max_size:
- raise ValueError("Must have cols <= max_size. Instead got "
- f"cols={cols}, max_size={max_size}.")
+ raise ValueError(
+ 'Must have cols <= max_size. Instead got '
+ f'cols={cols}, max_size={max_size}.'
+ )
if rows == max_size:
return mat
pad_size = max_size - rows
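A sketch of the [[M, 0], [0, I]] layout the docstring promises, built with jnp.block as an illustrative re-expression rather than the function body:

import jax.numpy as jnp

def pad_square_matrix_sketch(mat, max_size):
  rows = mat.shape[0]
  pad = max_size - rows
  if pad == 0:
    return mat
  # The identity block keeps the padded statistics matrix invertible.
  return jnp.block([
    [mat, jnp.zeros((rows, pad), mat.dtype)],
    [jnp.zeros((pad, rows), mat.dtype), jnp.eye(pad, dtype=mat.dtype)],
  ])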
@@ -923,13 +986,13 @@ def pad_square_matrix(mat, max_size):
def pad_vector(vec, max_size):
"""Pad a vector to a max_size.
- Args:
- vec: a vector to pad.
- max_size: matrix size requested.
+ Args:
+ vec: a vector to pad.
+ max_size: matrix size requested.
- Returns:
- Given V returns [V, 0]
- """
+ Returns:
+ Given V returns [V, 0]
+ """
size = vec.shape[0]
assert size <= max_size
if size == max_size:
@@ -949,9 +1012,9 @@ def _iter_body(unused_state):
def _iter_condition(state):
return state[0]
- results = jax.lax.while_loop(_iter_condition,
- _iter_body,
- tuple([predicate] + init_state))
+ results = jax.lax.while_loop(
+ _iter_condition, _iter_body, tuple([predicate] + init_state)
+ )
return tuple(results[1:])
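For reference, a self-contained sketch of the efficient_cond pattern these lines belong to: run a computation at most once under a traced predicate by carrying the predicate through lax.while_loop. This mirrors the upstream helper rather than reproducing its exact signature.

import jax
import jax.numpy as jnp

def efficient_cond_sketch(predicate, compute_fn, init_state):
  def _iter_body(state):
    results = compute_fn(*state[1:])
    return tuple([False] + list(results))  # clear the flag: body runs once

  def _iter_condition(state):
    return state[0]                        # keep looping only while flagged

  results = jax.lax.while_loop(
    _iter_condition, _iter_body, tuple([predicate] + list(init_state))
  )
  return tuple(results[1:])

# efficient_cond_sketch(jnp.array(True), lambda x: (x + 1.0,), [jnp.zeros(())])
# returns (1.0,); a False predicate returns the init_state unchanged.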
@@ -985,7 +1048,7 @@ def partition(self, tensor):
assert tensor.shape == self._shape
tensors = [tensor]
- for (i, indices) in self._splits:
+ for i, indices in self._splits:
tensors_local = []
for t in tensors:
tensors_local.extend(jnp.split(t, indices_or_sections=indices, axis=i))
@@ -995,13 +1058,14 @@ def partition(self, tensor):
def merge_partitions(self, partitions):
"""Merge partitions back to original shape."""
- for (i, indices) in reversed(self._splits):
+ for i, indices in reversed(self._splits):
n = len(indices) + 1
partial_merged_tensors = []
ind = 0
while ind < len(partitions):
partial_merged_tensors.append(
- jnp.concatenate(partitions[ind:ind + n], axis=i))
+ jnp.concatenate(partitions[ind : ind + n], axis=i)
+ )
ind += n
partitions = partial_merged_tensors
assert len(partitions) == 1
@@ -1011,25 +1075,25 @@ def merge_partitions(self, partitions):
def gram_weighted_update(old_stats, g, axis, w1, w2, precision=None):
"""Updated statistics via weighted average with new Gram matrix.
- Returns w₁ R + w₂ Gᵀ G where R is `old_stats` and G is the matrix whose
- columns are the flattened slices of the tensor `g` along the given `axis`.
- (So, `old_stats` and the returned matrix have dimensions n x n where
- n = `g.shape[axis]`).
-
- Args:
- old_stats: Old statistics.
- g: Gradient tensor.
- axis: Axis along which to slice `g`.
- w1: Scalar weight for old statistics.
- w2: Scalar weight for new Gram matrix.
- precision: Optional precision XLA related flag, the available options are:
- a) lax.Precision.DEFAULT (better step time, but not precise) b)
- lax.Precision.HIGH (increased precision, slower) c)
- lax.Precision.HIGHEST (best possible precision, slowest)
-
- Returns:
- Weighted average of old and new statistics.
- """
+ Returns w₁ R + w₂ Gᵀ G where R is `old_stats` and G is the matrix whose
+ columns are the flattened slices of the tensor `g` along the given `axis`.
+ (So, `old_stats` and the returned matrix have dimensions n x n where
+ n = `g.shape[axis]`).
+
+ Args:
+ old_stats: Old statistics.
+ g: Gradient tensor.
+ axis: Axis along which to slice `g`.
+ w1: Scalar weight for old statistics.
+ w2: Scalar weight for new Gram matrix.
+ precision: Optional precision XLA related flag, the available options are:
+ a) lax.Precision.DEFAULT (better step time, but not precise) b)
+ lax.Precision.HIGH (increased precision, slower) c)
+ lax.Precision.HIGHEST (best possible precision, slowest)
+
+ Returns:
+ Weighted average of old and new statistics.
+ """
axes = [i for i in range(g.ndim) if i != axis]
gram_matrix = jnp.tensordot(g, g, axes=(axes, axes), precision=precision)
return w1 * old_stats + w2 * gram_matrix
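A small shape check for the update above, with illustrative sizes: for a rank-3 gradient and axis=0, the tensordot over the remaining axes is exactly GᵀG with G the axis-0 unfolding of g.

import jax
import jax.numpy as jnp

g = jax.random.normal(jax.random.PRNGKey(0), (8, 4, 5))   # gradient tensor
old_stats = jnp.zeros((8, 8))                              # n = g.shape[0]
axes = [1, 2]                                              # every axis but 0
gram = jnp.tensordot(g, g, axes=(axes, axes))              # (8, 8) Gram matrix
new_stats = 0.9 * old_stats + 0.1 * gram                   # w1 * R + w2 * G^T G
print(new_stats.shape)                                     # (8, 8)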
@@ -1039,67 +1103,68 @@ class Preconditioner:
"""Compute statistics/shape from gradients for preconditioning."""
def __init__(
- self,
- param,
- block_size,
- merge_small_dims_block_size,
- best_effort_shape_interpretation,
- preconditioner_type=PreconditionerType.ALL,
+ self,
+ param,
+ block_size,
+ merge_small_dims_block_size,
+ best_effort_shape_interpretation,
+ preconditioner_type=PreconditionerType.ALL,
):
"""Initializes the preconditioner.
- Args:
- param: parameter to precondition.
- block_size: Block size used to split param.
- merge_small_dims_block_size: Block size for merging dims.
- best_effort_shape_interpretation: Whether to
- collapse/merge dims together.
- preconditioner_type: Type of preconditioner to use.
- """
+ Args:
+ param: parameter to precondition.
+ block_size: Block size used to split param.
+ merge_small_dims_block_size: Block size for merging dims.
+ best_effort_shape_interpretation: Whether to
+ collapse/merge dims together.
+ preconditioner_type: Type of preconditioner to use.
+ """
self._original_shape = param.shape
self._transformed_shape = param.shape
if best_effort_shape_interpretation:
- self._transformed_shape = merge_small_dims(self._original_shape,
- merge_small_dims_block_size)
+ self._transformed_shape = merge_small_dims(
+ self._original_shape, merge_small_dims_block_size
+ )
reshaped_param = jnp.reshape(param, self._transformed_shape)
self._partitioner = BlockPartitioner(reshaped_param, block_size)
self._preconditioner_type = preconditioner_type
def updated_statistics_from_grad(
- self,
- stats,
- grad,
- w1,
- w2,
- to_float=None,
- from_float=None,
- precision=None,
+ self,
+ stats,
+ grad,
+ w1,
+ w2,
+ to_float=None,
+ from_float=None,
+ precision=None,
):
"""Update statistics from gradients.
- Args:
- stats: Old statistics or its Cholesky factor if `cholesky` is True.
- grad: Gradient to compute statistics from.
- w1: Weight for old statistics.
- w2: Weight for new statistics.
- to_float: Optional function for converting stats to floating point.
- from_float: Optional function for converting from floating point.
- precision: Optional precision XLA related flag, the available options
- are:
- a) lax.Precision.DEFAULT (better step time, but not precise) b)
- lax.Precision.HIGH (increased precision, slower) c)
- lax.Precision.HIGHEST (best possible precision, slowest)
-
- Returns:
- A list of updated gradient statistics for each partition.
- """
+ Args:
+ stats: Old statistics or its Cholesky factor if `cholesky` is True.
+ grad: Gradient to compute statistics from.
+ w1: Weight for old statistics.
+ w2: Weight for new statistics.
+ to_float: Optional function for converting stats to floating point.
+ from_float: Optional function for converting from floating point.
+ precision: Optional precision XLA related flag, the available options
+ are:
+ a) lax.Precision.DEFAULT (better step time, but not precise) b)
+ lax.Precision.HIGH (increased precision, slower) c)
+ lax.Precision.HIGHEST (best possible precision, slowest)
+
+ Returns:
+ A list of updated gradient statistics for each partition.
+ """
to_float = to_float if to_float is not None else (lambda x: x)
from_float = from_float if from_float is not None else (lambda x: x)
reshaped_grad = jnp.reshape(grad, self._transformed_shape)
partitioned_grads = self._partitioner.partition(reshaped_grad)
should_preconditioned_dims = self.should_precondition_dims()
preconditioned_dims = [
- i for i, p in enumerate(should_preconditioned_dims) if p
+ i for i, p in enumerate(should_preconditioned_dims) if p
]
new_stats = []
index = 0
@@ -1136,8 +1201,7 @@ def _preconds_for_grad(self, preconditioners, rank, start, end):
elif self._preconditioner_type == PreconditionerType.OUTPUT:
# When _preconditioner_type is OUTPUT, we append (rank - 1) many None
# values to the beginning of the list to handle the False indices.
- preconditioners_for_grad = [None] * \
- (rank - 1) + preconditioners_for_grad
+ preconditioners_for_grad = [None] * (rank - 1) + preconditioners_for_grad
assert len(preconditioners_for_grad) == rank
return preconditioners_for_grad
@@ -1165,13 +1229,13 @@ def exponent_for_preconditioner(self):
def preconditioned_grad(self, grad, preconditioners):
"""Precondition the gradient.
- Args:
- grad: A gradient tensor to precondition.
- preconditioners: A list of preconditioners to apply.
+ Args:
+ grad: A gradient tensor to precondition.
+ preconditioners: A list of preconditioners to apply.
- Returns:
- A preconditioned gradient.
- """
+ Returns:
+ A preconditioned gradient.
+ """
reshaped_grad = jnp.reshape(grad, self._transformed_shape)
partitioned_grads = self._partitioner.partition(reshaped_grad)
should_preconditioned_dims = self.should_precondition_dims()
@@ -1179,17 +1243,18 @@ def preconditioned_grad(self, grad, preconditioners):
preconditioned_partitioned_grads = []
for i, g in enumerate(partitioned_grads):
preconditioners_for_grad = self._preconds_for_grad(
- preconditioners,
- rank=len(should_preconditioned_dims),
- start=i * num_preconditioners,
- end=(i + 1) * num_preconditioners,
+ preconditioners,
+ rank=len(should_preconditioned_dims),
+ start=i * num_preconditioners,
+ end=(i + 1) * num_preconditioners,
+ )
+ precond_g = self._precondition_block(
+ g, should_preconditioned_dims, preconditioners_for_grad
)
- precond_g = self._precondition_block(g,
- should_preconditioned_dims,
- preconditioners_for_grad)
preconditioned_partitioned_grads.append(precond_g)
merged_grad = self._partitioner.merge_partitions(
- preconditioned_partitioned_grads)
+ preconditioned_partitioned_grads
+ )
return jnp.reshape(merged_grad, self._original_shape)
def _precondition_block(self, g, should_precondition_dim, preconditioners):
@@ -1208,9 +1273,9 @@ def _precondition_block(self, g, should_precondition_dim, preconditioners):
return g
-def _convert_to_parameter_stats(global_stats,
- local_stat,
- convert_statistics=True):
+def _convert_to_parameter_stats(
+ global_stats, local_stat, convert_statistics=True
+):
"""Creates parameter stats from sharded stats."""
index_start = int(local_stat.index_start)
index_end = int(len(local_stat.sizes)) + index_start
@@ -1225,24 +1290,24 @@ def _convert_to_parameter_stats(global_stats,
if not convert_statistics:
new_statistics = None
return ParameterStats(
- local_stat.diagonal_statistics,
- new_statistics,
- new_preconditioners,
- local_stat.diagonal_momentum,
- local_stat.momentum,
- local_stat.training_metrics,
+ local_stat.diagonal_statistics,
+ new_statistics,
+ new_preconditioners,
+ local_stat.diagonal_momentum,
+ local_stat.momentum,
+ local_stat.training_metrics,
)
def _convert_from_parameter_stats(parameter_stats, local_stats):
"""Creates sharded stats from paramter stats."""
return LocalShardedParameterStats(
- parameter_stats.diagonal_statistics,
- parameter_stats.diagonal_momentum,
- parameter_stats.momentum,
- parameter_stats.training_metrics,
- local_stats.index_start,
- local_stats.sizes,
+ parameter_stats.diagonal_statistics,
+ parameter_stats.diagonal_momentum,
+ parameter_stats.momentum,
+ parameter_stats.training_metrics,
+ local_stats.index_start,
+ local_stats.sizes,
)
@@ -1258,12 +1323,13 @@ def _add_metrics_into_local_stats(local_stats, metrics, keep_old):
# root calculation to find a new preconditioner, so that TensorBoard curves
# look consistent (otherwise they'd oscillate between NaN and measured
# values).
- per_stat_metrics = efficient_cond(keep_old,
- lambda: [local_stat.training_metrics],
- [per_stat_metrics])[0]
+ per_stat_metrics = efficient_cond(
+ keep_old, lambda: [local_stat.training_metrics], [per_stat_metrics]
+ )[0]
# pylint:enable=cell-var-from-loop
new_local_stats.append(
- local_stat.replace(training_metrics=per_stat_metrics))
+ local_stat.replace(training_metrics=per_stat_metrics)
+ )
return new_local_stats
@@ -1271,7 +1337,7 @@ def batch(x, num_devices):
"""Batch `x` so that so that leading axis is num_devices."""
n = len(x)
b = int(n / num_devices)
- return jnp.stack([jnp.stack(x[idx:idx + b]) for idx in range(0, n, b)])
+ return jnp.stack([jnp.stack(x[idx : idx + b]) for idx in range(0, n, b)])
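A quick shape check for the helper above, with illustrative sizes: eight per-block statistics batched over two devices become a (2, 4, 3, 3) stack.

import jax.numpy as jnp

x = [jnp.full((3, 3), i, jnp.float32) for i in range(8)]   # 8 block statistics
num_devices = 2
b = len(x) // num_devices
batched = jnp.stack(
  [jnp.stack(x[idx : idx + b]) for idx in range(0, len(x), b)]
)
print(batched.shape)   # (2, 4, 3, 3): leading axis is num_devices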
def unbatch(batched_values):
@@ -1290,162 +1356,168 @@ def unbatch(batched_values):
def distributed_shampoo(
- learning_rate,
- block_size=1024,
- beta1=0.9,
- beta2=0.999,
- diagonal_epsilon=1e-8,
- matrix_epsilon=1e-6,
- weight_decay=0.0,
- start_preconditioning_step=101,
- preconditioning_compute_steps=20,
- statistics_compute_steps=1,
- best_effort_shape_interpretation=True,
- graft_type=GraftingType.RMSPROP_NORMALIZED,
- nesterov=True,
- exponent_override=0,
- # Pass pmap 'batch axis name' in pmap mode.
- batch_axis_name=None,
- # Only set following 3 params in pjit/spmd mode.
- # WARNING: Experimental
- statistics_partition_spec=None,
- preconditioner_partition_spec=None,
- num_devices_for_pjit=None,
- shard_optimizer_states=False,
- ###
- # Experimental memory reduction mode
- best_effort_memory_usage_reduction=True,
- ###
- inverse_failure_threshold=0.1,
- moving_average_for_momentum=True,
- skip_preconditioning_dim_size_gt=0,
- clip_by_scaled_gradient_norm=None,
- precision=lax.Precision.HIGHEST,
- tensordot_precision=None,
- relative_matrix_epsilon=True,
- merge_small_dims_block_size=4096,
- lobpcg_topk_precondition=0,
- lobpcg_max_iter=0,
- precondtioner_type=PreconditionerType.ALL,
- custom_preconditioner=False,
- skip_preconditioning_rank_lt=1,
- decoupled_learning_rate=True,
- decoupled_weight_decay=False,
- generate_training_metrics=True,
- reuse_preconditioner=False,
- eigh=True,
+ learning_rate,
+ block_size=1024,
+ beta1=0.9,
+ beta2=0.999,
+ diagonal_epsilon=1e-8,
+ matrix_epsilon=1e-6,
+ weight_decay=0.0,
+ start_preconditioning_step=101,
+ preconditioning_compute_steps=20,
+ statistics_compute_steps=1,
+ best_effort_shape_interpretation=True,
+ graft_type=GraftingType.RMSPROP_NORMALIZED,
+ nesterov=True,
+ exponent_override=0,
+ # Pass pmap 'batch axis name' in pmap mode.
+ batch_axis_name=None,
+ # Only set following 3 params in pjit/spmd mode.
+ # WARNING: Experimental
+ statistics_partition_spec=None,
+ preconditioner_partition_spec=None,
+ num_devices_for_pjit=None,
+ shard_optimizer_states=False,
+ ###
+ # Experimental memory reduction mode
+ best_effort_memory_usage_reduction=True,
+ ###
+ inverse_failure_threshold=0.1,
+ moving_average_for_momentum=True,
+ skip_preconditioning_dim_size_gt=0,
+ clip_by_scaled_gradient_norm=None,
+ precision=lax.Precision.HIGHEST,
+ tensordot_precision=None,
+ relative_matrix_epsilon=True,
+ merge_small_dims_block_size=4096,
+ lobpcg_topk_precondition=0,
+ lobpcg_max_iter=0,
+ precondtioner_type=PreconditionerType.ALL,
+ custom_preconditioner=False,
+ skip_preconditioning_rank_lt=1,
+ decoupled_learning_rate=True,
+ decoupled_weight_decay=False,
+ generate_training_metrics=True,
+ reuse_preconditioner=False,
+ eigh=True,
):
"""Distributed Shampoo optimizer.
- Distributed Shampoo is a second-order preconditioned method (concretely, a
- variant of full-matrix Adagrad), that provides significant convergence and
- wall-clock time improvements compared to conventional first-order methods,
- and that has been shown to scale to large state-of-the-art deep learning
- models.
-
- References:
- Scalable Second Order Optimization for Deep Learning,
- Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
-
- Preprint: https://arxiv.org/abs/2002.09018
-
- Args:
- learning_rate: the step size used to update the parameters.
- block_size: Block size for large layers (if > 0). Preconditioning compute
- operation is cubic in the dimension of the tensor. Block size allows us
- to chunk the layers into sub-layers of maximal dimension dictated by
- this value. Use 128 as default (increase if you have compute budget).
- beta1: momentum parameter.
- beta2: second moment averaging parameter.
- diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting
- to AdaGrad is enabled).
- matrix_epsilon: epsilon to add to statistics before computing inverse pth
- root. If you are running in f32 precision for inverse pth root
- (recommended today) this can go upto 1e-6. If you have latest hardware
- with native f64 precision, set this upto 1e-12.
- weight_decay: Weight decay for regularization.
- start_preconditioning_step: When to start Shampoo update before which
- diagonal update is used. This is because we dont have enough information
- to do stable inverse.
- preconditioning_compute_steps: How often to compute preconditioner.
- Performance tuning params for controlling memory and compute
- requirements.
- Ideally set this and statistics_compute_steps params to 1.
- statistics_compute_steps: How often to compute statistics.
- best_effort_shape_interpretation: If there are some small dimensions,
- collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if
- block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
- graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
- optimizer. This allows us to plugin the Shampoo optimizer into settings
- where SGD/AdaGrad is already well tuned.
- nesterov: Nesterov momentum.
- exponent_override: Override the exponent used in matrix inverse.
- batch_axis_name: labeled axis over pmap for data-parallel training the
- optimizer used for.
- statistics_partition_spec: PartitionSpec to be used in sharded mode.
- preconditioner_partition_spec: PartitionSpec to be used in sharded mode.
- num_devices_for_pjit: Number of devices to parallelize over when using
- pjit.
- shard_optimizer_states: Shard optimizer states to save memory in model
- parallel training.
- best_effort_memory_usage_reduction: Best effort memory usage reduction. -
- diagonal_statistics -> jnp.bfloat16 - momentum buffers (2x)
- -> jnp.int8 - statistics, preconditioners -> jnp.int16 + diagonals
- inverse_failure_threshold: numerics are hard and inverses fail sometimes;
- we determine that using this threshold.
- moving_average_for_momentum: Whether to use moving average for momentum
- instead of exponential moving average.
- skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
- greater than this value.
- clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful
- when using RMSProp Grafting).
- precision: precision XLA related flag, the available options are: a)
- lax.Precision.DEFAULT (better step time, but not precise) b)
- lax.Precision.HIGH (increased precision, slower) c)
- lax.Precision.HIGHEST (best possible precision, slowest)
- tensordot_precision: Optional precision to use for the tensordot operation
- when computing statistics (e.g., G Gᵀ). Same options as `precision`
- above.
- relative_matrix_epsilon: Whether to use relative epsilon to the max eigen
- value when computing inverse-pth root.
- merge_small_dims_block_size: Used as the maximum block size to merge the
- shapes.
- lobpcg_topk_precondition: If nonzero, specifies the number of top
- eigenvectors to subtract out before performing LOBPCG. Note this makes
- relative_matrix_epsilon essentially free.
- lobpcg_max_iter: Number of LOBPCG iterations, if zero defaults to
- `lobpcg_topk_precondition`.
- precondtioner_type: Preconditioner type to select all, left only or right
- only preconditioners.
- skip_preconditioning_rank_lt: Skips preconditioning for parameters with
- rank less than this value.
- decoupled_learning_rate: If True, use decoupled learning rate, otherwise
- couple it with preconditioned gradient computation. (Default True)
- decoupled_weight_decay: If True, use decoupled weight decay, otherwise
- couple with weight decay. (Default False)
- generate_training_metrics: If True, gather training metrics, otherwise
- avoid generating them (to reduce memory usage).
- reuse_preconditioner: If True, pass the previous derived preconditioner
- as a warm start to the next iteratin's inverse pth root computation.
- eigh: If True, and uses eigen decomposition for inverse-pth root.
-
- Returns:
- a GradientTransformation.
- """
+ Distributed Shampoo is a second-order preconditioned method (concretely, a
+ variant of full-matrix Adagrad), that provides significant convergence and
+ wall-clock time improvements compared to conventional first-order methods,
+ and that has been shown to scale to large state-of-the-art deep learning
+ models.
+
+ References:
+ Scalable Second Order Optimization for Deep Learning,
+ Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
+
+ Preprint: https://arxiv.org/abs/2002.09018
+
+ Args:
+ learning_rate: the step size used to update the parameters.
+ block_size: Block size for large layers (if > 0). Preconditioning compute
+ operation is cubic in the dimension of the tensor. Block size allows us
+ to chunk the layers into sub-layers of maximal dimension dictated by
+ this value. Use 128 as default (increase if you have compute budget).
+ beta1: momentum parameter.
+ beta2: second moment averaging parameter.
+ diagonal_epsilon: epsilon for diagonal adagrad (only if layerwise grafting
+ to AdaGrad is enabled).
+ matrix_epsilon: epsilon to add to statistics before computing inverse pth
+ root. If you are running in f32 precision for inverse pth root
+      (recommended today) this can go up to 1e-6. If you have the latest
+      hardware with native f64 precision, set this up to 1e-12.
+ weight_decay: Weight decay for regularization.
+    start_preconditioning_step: When to start the Shampoo update; before this
+      step the diagonal update is used, since we don't have enough information
+      to compute a stable inverse.
+ preconditioning_compute_steps: How often to compute preconditioner.
+ Performance tuning params for controlling memory and compute
+ requirements.
+ Ideally set this and statistics_compute_steps params to 1.
+ statistics_compute_steps: How often to compute statistics.
+ best_effort_shape_interpretation: If there are some small dimensions,
+ collapse them e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if
+ block = 1024, [1, 2, 768, 1, 2048] --> [2, 768, 2048]
+ graft_type: Grafting is a technique to fix the layerwise scale of Shampoo
+      optimizer. This allows us to plug the Shampoo optimizer into settings
+ where SGD/AdaGrad is already well tuned.
+ nesterov: Nesterov momentum.
+ exponent_override: Override the exponent used in matrix inverse.
+    batch_axis_name: labeled pmap axis name over which the optimizer is used
+      for data-parallel training.
+ statistics_partition_spec: PartitionSpec to be used in sharded mode.
+ preconditioner_partition_spec: PartitionSpec to be used in sharded mode.
+ num_devices_for_pjit: Number of devices to parallelize over when using
+ pjit.
+ shard_optimizer_states: Shard optimizer states to save memory in model
+ parallel training.
+    best_effort_memory_usage_reduction: Best effort memory usage reduction:
+      diagonal_statistics -> jnp.bfloat16, momentum buffers (2x) -> jnp.int8,
+      statistics and preconditioners -> jnp.int16 + diagonals.
+ inverse_failure_threshold: numerics are hard and inverses fail sometimes;
+ we determine that using this threshold.
+ moving_average_for_momentum: Whether to use moving average for momentum
+ instead of exponential moving average.
+ skip_preconditioning_dim_size_gt: Skip if preconditioning dim size is
+ greater than this value.
+ clip_by_scaled_gradient_norm: Clip by scaled gradient norm (only useful
+ when using RMSProp Grafting).
+ precision: precision XLA related flag, the available options are: a)
+ lax.Precision.DEFAULT (better step time, but not precise) b)
+ lax.Precision.HIGH (increased precision, slower) c)
+ lax.Precision.HIGHEST (best possible precision, slowest)
+ tensordot_precision: Optional precision to use for the tensordot operation
+ when computing statistics (e.g., G Gᵀ). Same options as `precision`
+ above.
+ relative_matrix_epsilon: Whether to use relative epsilon to the max eigen
+ value when computing inverse-pth root.
+ merge_small_dims_block_size: Used as the maximum block size to merge the
+ shapes.
+ lobpcg_topk_precondition: If nonzero, specifies the number of top
+ eigenvectors to subtract out before performing LOBPCG. Note this makes
+ relative_matrix_epsilon essentially free.
+ lobpcg_max_iter: Number of LOBPCG iterations, if zero defaults to
+ `lobpcg_topk_precondition`.
+ precondtioner_type: Preconditioner type to select all, left only or right
+ only preconditioners.
+ skip_preconditioning_rank_lt: Skips preconditioning for parameters with
+ rank less than this value.
+ decoupled_learning_rate: If True, use decoupled learning rate, otherwise
+ couple it with preconditioned gradient computation. (Default True)
+ decoupled_weight_decay: If True, use decoupled weight decay, otherwise
+ couple with weight decay. (Default False)
+ generate_training_metrics: If True, gather training metrics, otherwise
+ avoid generating them (to reduce memory usage).
+ reuse_preconditioner: If True, pass the previous derived preconditioner
+      as a warm start to the next iteration's inverse pth root computation.
+    eigh: If True, uses eigen decomposition for inverse-pth root.
+
+ Returns:
+ a GradientTransformation.
+ """
reset_frequency = None
def _graft_type_has_diagonal_statistics():
"""Returns True if using diagonal firt order method for grafting."""
return graft_type not in [
- GraftingType.SGD, GraftingType.SQRT_N, GraftingType.NONE
+ GraftingType.SGD,
+ GraftingType.SQRT_N,
+ GraftingType.NONE,
]
def quantized_dtype_for_momentum_buffers(var):
- return jnp.int8 if best_effort_memory_usage_reduction and len(
- var.shape) > 1 else jnp.float32
+ return (
+ jnp.int8
+ if best_effort_memory_usage_reduction and len(var.shape) > 1
+ else jnp.float32
+ )
quantize_second_moment = (
- best_effort_memory_usage_reduction and batch_axis_name)
+ best_effort_memory_usage_reduction and batch_axis_name
+ )
# Preconditioner and statistics are both stores as int16 in this mode.
# We take out the diagonal to make quantization easier.
@@ -1472,19 +1544,20 @@ def _to_float(maybe_quantized):
def preconditioner_from_params(param):
"""Returns a Preconditioner object for given param."""
return Preconditioner(
- param,
- block_size,
- merge_small_dims_block_size,
- best_effort_shape_interpretation,
- precondtioner_type,
+ param,
+ block_size,
+ merge_small_dims_block_size,
+ best_effort_shape_interpretation,
+ precondtioner_type,
)
def precond_dim(max_size):
"""Derives largest preconditioner dimension."""
return max_size
- def pad_and_maybe_zero_preconditioners(preconditioners, total, max_size,
- step):
+ def pad_and_maybe_zero_preconditioners(
+ preconditioners, total, max_size, step
+ ):
"""Pad preconditioners up to total x max_size x precond_dim(max_size)."""
pd = precond_dim(max_size)
@@ -1513,9 +1586,9 @@ def _pad_preconditioner(preconditioner):
def sharded_init_fn(params):
"""Returns optimizer state (for PJIT mode).
- Args:
- params: the parameters that should be updated.
- """
+ Args:
+ params: the parameters that should be updated.
+ """
params_flat, treedef = jax.tree_util.tree_flatten(params)
# Find max size to pad to.
max_size = 0
@@ -1542,21 +1615,22 @@ def sharded_init_fn(params):
sizes = [s[0] for s in shapes]
shapes = preconditioner.shapes_for_preconditioners()
statistics = [
- matrix_epsilon * jnp.eye(max_size, dtype=jnp.float32)
- for s in shapes
+ matrix_epsilon * jnp.eye(max_size, dtype=jnp.float32) for s in shapes
]
pd = precond_dim(max_size)
# If the preconditioner is using a low-rank representation, initialize
# it to zero instead of an invalid eye.
preconditioners = [
- jnp.eye(max_size, pd, dtype=jnp.float32) * (pd == max_size)
- for s in shapes
+ jnp.eye(max_size, pd, dtype=jnp.float32) * (pd == max_size)
+ for s in shapes
]
padded_statistics.extend(statistics)
padded_preconditioners.extend(preconditioners)
exponent = (
- preconditioner.exponent_for_preconditioner()
- if exponent_override == 0 else exponent_override)
+ preconditioner.exponent_for_preconditioner()
+ if exponent_override == 0
+ else exponent_override
+ )
exponents.extend([exponent] * len(shapes))
diagonal_statistics = jnp.zeros_like(param)
@@ -1564,16 +1638,18 @@ def sharded_init_fn(params):
momentum = jnp.zeros_like(param)
local_stats_flat.append(
- LocalShardedParameterStats(
- diagonal_statistics,
- diagonal_momentum,
- momentum,
- init_training_metrics(
- len(sizes),
- generate_training_metrics,
- ),
- index_start,
- sizes))
+ LocalShardedParameterStats(
+ diagonal_statistics,
+ diagonal_momentum,
+ momentum,
+ init_training_metrics(
+ len(sizes),
+ generate_training_metrics,
+ ),
+ index_start,
+ sizes,
+ )
+ )
local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat)
to_pad = -len(padded_statistics) % num_devices_for_pjit
@@ -1588,22 +1664,27 @@ def sharded_init_fn(params):
# TODO(rohananil): Relax to only the size of the mesh axis where the dim
# is split on.
padded_statistics.extend(
- [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)])
+ [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]
+ )
pd = precond_dim(max_size)
# If the preconditioner is using a low-rank representation, initialize
# it to zero instead of an invalid eye.
- padded_preconditioners.extend([
+ padded_preconditioners.extend(
+ [
jnp.eye(max_size, pd, dtype=stat_dtype) * (pd == max_size)
for _ in range(to_pad)
- ])
+ ]
+ )
exponents.extend([1 for _ in range(to_pad)])
global_stats = GlobalShardedParameterStats(
- jnp.stack(padded_statistics),
- jnp.stack(padded_preconditioners),
- jnp.stack(exponents))
+ jnp.stack(padded_statistics),
+ jnp.stack(padded_preconditioners),
+ jnp.stack(exponents),
+ )
return ShampooState(
- count=jnp.zeros([], jnp.int32),
- stats=ShardedShampooStats(global_stats, local_stats))
+ count=jnp.zeros([], jnp.int32),
+ stats=ShardedShampooStats(global_stats, local_stats),
+ )
def _max_statistics_size_from_params(params):
max_size = 0
@@ -1624,20 +1705,21 @@ def _remove_leading_sharding_annotation(pspec):
else:
return []
- def sharded_init_partition_spec_fn(params,
- params_partition_spec,
- partition_spec_for_statistics):
+ def sharded_init_partition_spec_fn(
+ params, params_partition_spec, partition_spec_for_statistics
+ ):
"""Returns a parallel state tree with PartitionSpec associated with state.
- Args:
- params: A pytree with params.
- params_partition_spec: A pytree with PartitionSpec for params.
- partition_spec_for_statistics: PartitionSpec for the statistics.
- """
+ Args:
+ params: A pytree with params.
+ params_partition_spec: A pytree with PartitionSpec for params.
+ partition_spec_for_statistics: PartitionSpec for the statistics.
+ """
# Parallel lists of spec, and params.
param_pspec_flat, _ = jax.tree_util.tree_flatten(
- params_partition_spec, is_leaf=lambda x: x is None)
+ params_partition_spec, is_leaf=lambda x: x is None
+ )
params_flat, treedef = jax.tree_util.tree_flatten(params)
assert param_pspec_flat
assert params_flat
@@ -1667,48 +1749,57 @@ def sharded_init_partition_spec_fn(params,
m2_scale_pspec = _remove_leading_sharding_annotation(m2_pspec)
local_stats_flat.append(
- LocalShardedParameterStats(
- QuantizedValue(
- param_pspec,
- [],
- [],
- jnp.float32,
- False, # pytype: disable=wrong-arg-types # numpy-scalars
- list(param.shape)),
- QuantizedValue(
- m1_pspec,
- [],
- m1_scale_pspec,
- qdtype,
- False, # pytype: disable=wrong-arg-types # numpy-scalars
- list(param.shape)),
- QuantizedValue(
- m2_pspec,
- [],
- m2_scale_pspec,
- qdtype,
- False, # pytype: disable=wrong-arg-types # numpy-scalars
- list(param.shape)),
- init_training_metrics_pspec(generate_training_metrics,),
- index_start,
- sizes))
+ LocalShardedParameterStats(
+ QuantizedValue(
+ param_pspec,
+ [],
+ [],
+ jnp.float32,
+ False, # pytype: disable=wrong-arg-types # numpy-scalars
+ list(param.shape),
+ ),
+ QuantizedValue(
+ m1_pspec,
+ [],
+ m1_scale_pspec,
+ qdtype,
+ False, # pytype: disable=wrong-arg-types # numpy-scalars
+ list(param.shape),
+ ),
+ QuantizedValue(
+ m2_pspec,
+ [],
+ m2_scale_pspec,
+ qdtype,
+ False, # pytype: disable=wrong-arg-types # numpy-scalars
+ list(param.shape),
+ ),
+ init_training_metrics_pspec(
+ generate_training_metrics,
+ ),
+ index_start,
+ sizes,
+ )
+ )
local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat)
- global_stats = GlobalShardedParameterStats(partition_spec_for_statistics,
- partition_spec_for_statistics,
- jax.sharding.PartitionSpec())
+ global_stats = GlobalShardedParameterStats(
+ partition_spec_for_statistics,
+ partition_spec_for_statistics,
+ jax.sharding.PartitionSpec(),
+ )
count_pspec = jax.sharding.PartitionSpec()
return ShampooState( # pytype: disable=wrong-arg-types # numpy-scalars
- count=count_pspec,
- stats=ShardedShampooStats(global_stats, local_stats))
+ count=count_pspec, stats=ShardedShampooStats(global_stats, local_stats)
+ )
def sharded_init_shape_and_dtype_fn(params):
"""Returns a parallel state tree with shape, dtype associated with state.
- Args:
- params: A pytree with params.
- """
+ Args:
+ params: A pytree with params.
+ """
# Parallel lists of spec, and params.
params_flat, treedef = jax.tree_util.tree_flatten(params)
assert params_flat
@@ -1739,31 +1830,39 @@ def sharded_init_shape_and_dtype_fn(params):
diagonal_statistics_shape_and_dtype = [list(param.shape), param.dtype]
local_stats_flat.append(
- LocalShardedParameterStats(
- QuantizedValue(
- diagonal_statistics_shape_and_dtype,
- [],
- [], # pytype: disable=wrong-arg-types # numpy-scalars
- jnp.float32,
- False,
- list(param.shape)),
- QuantizedValue(m1_shape_and_dtype, [],
- m1_scale_shape_and_dtype,
- qdtype,
- False,
- list(param.shape)),
- QuantizedValue(m2_shape_and_dtype, [],
- m2_scale_shape_and_dtype,
- qdtype,
- False,
- list(param.shape)),
- init_training_metrics_shapes(
- len(sizes),
- generate_training_metrics,
- ),
- index_start,
- sizes,
- ))
+ LocalShardedParameterStats(
+ QuantizedValue(
+ diagonal_statistics_shape_and_dtype,
+ [],
+ [], # pytype: disable=wrong-arg-types # numpy-scalars
+ jnp.float32,
+ False,
+ list(param.shape),
+ ),
+ QuantizedValue(
+ m1_shape_and_dtype,
+ [],
+ m1_scale_shape_and_dtype,
+ qdtype,
+ False,
+ list(param.shape),
+ ),
+ QuantizedValue(
+ m2_shape_and_dtype,
+ [],
+ m2_scale_shape_and_dtype,
+ qdtype,
+ False,
+ list(param.shape),
+ ),
+ init_training_metrics_shapes(
+ len(sizes),
+ generate_training_metrics,
+ ),
+ index_start,
+ sizes,
+ )
+ )
local_stats = jax.tree_util.tree_unflatten(treedef, local_stats_flat)
max_statistics_size = _max_statistics_size_from_params(params_flat)
@@ -1773,29 +1872,36 @@ def sharded_init_shape_and_dtype_fn(params):
num_statistics = num_devices_for_pjit
max_statistics_size = block_size
statistics_shape = [
- num_statistics, max_statistics_size, max_statistics_size
+ num_statistics,
+ max_statistics_size,
+ max_statistics_size,
]
preconditioners_shape = [
- num_statistics, max_statistics_size, precond_dim(max_statistics_size)
+ num_statistics,
+ max_statistics_size,
+ precond_dim(max_statistics_size),
]
global_stats = GlobalShardedParameterStats(
- [statistics_shape, jnp.float32], [preconditioners_shape, jnp.float32],
- [[num_statistics], jnp.int32])
+ [statistics_shape, jnp.float32],
+ [preconditioners_shape, jnp.float32],
+ [[num_statistics], jnp.int32],
+ )
return ShampooState( # pytype: disable=wrong-arg-types # numpy-scalars
- count=[[], jnp.float32],
- stats=ShardedShampooStats(global_stats, local_stats))
+ count=[[], jnp.float32],
+ stats=ShardedShampooStats(global_stats, local_stats),
+ )
def sharded_update_fn(grads, state, params):
"""Transform the input gradient and update all statistics in sharded mode.
- Args:
- grads: the gradient tensors for the parameters.
- state: a named tuple containing the state of the optimizer
- params: the parameters that should be updated.
+ Args:
+ grads: the gradient tensors for the parameters.
+ state: a named tuple containing the state of the optimizer
+ params: the parameters that should be updated.
- Returns:
- A tuple containing the new parameters and the new optimizer state.
- """
+ Returns:
+ A tuple containing the new parameters and the new optimizer state.
+ """
params_flat, treedef = jax.tree_util.tree_flatten(params)
grads_flat = treedef.flatten_up_to(grads)
@@ -1803,43 +1909,45 @@ def sharded_update_fn(grads, state, params):
local_stats_flat = treedef.flatten_up_to(state.stats.local_stats)
stats_flat = []
for local_stat in local_stats_flat:
- stats_flat.append(_convert_to_parameter_stats(
+ stats_flat.append(
+ _convert_to_parameter_stats(
global_stats,
local_stat,
- ))
+ )
+ )
new_stats_flat = jax.tree.map(
- lambda g,
- s,
- p: _compute_stats(g, s, p, state.count),
- grads_flat,
- stats_flat,
- params_flat)
+ lambda g, s, p: _compute_stats(g, s, p, state.count),
+ grads_flat,
+ stats_flat,
+ params_flat,
+ )
outputs = jax.tree.map(
- lambda g,
- s,
- p: _transform_grad(g, s, p, state.count),
- grads_flat,
- new_stats_flat,
- params_flat)
+ lambda g, s, p: _transform_grad(g, s, p, state.count),
+ grads_flat,
+ new_stats_flat,
+ params_flat,
+ )
updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
updates = jax.tree_util.tree_unflatten(treedef, updates_flat)
new_local_stats_flat = []
for new_stat, local_stat in zip(new_stats_flat, local_stats_flat):
new_local_stats_flat.append(
- _convert_from_parameter_stats(
- new_stat,
- local_stat,
- ))
+ _convert_from_parameter_stats(
+ new_stat,
+ local_stat,
+ )
+ )
max_size = global_stats.statistics.shape[1]
new_padded_statistics = []
padding_starts = []
for stat in new_stats_flat:
new_padded_statistics.extend(
- [pad_square_matrix(stat, max_size) for stat in stat.statistics])
+ [pad_square_matrix(stat, max_size) for stat in stat.statistics]
+ )
padding_starts.extend([len(stat) for stat in stat.statistics])
# Create global stats
@@ -1857,7 +1965,8 @@ def sharded_update_fn(grads, state, params):
stat_dtype = new_padded_statistics[0].dtype
new_padded_statistics.extend(
- [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)])
+ [jnp.eye(max_size, dtype=stat_dtype) for _ in range(to_pad)]
+ )
padding_starts += [0] * to_pad
if reuse_preconditioner:
@@ -1865,29 +1974,30 @@ def sharded_update_fn(grads, state, params):
for stat in new_stats_flat:
prev_preconditioners.extend(stat.preconditioners)
prev_padded_preconditioners = pad_and_maybe_zero_preconditioners(
- prev_preconditioners,
- len(new_padded_statistics),
- max_size,
- state.count)
+ prev_preconditioners, len(new_padded_statistics), max_size, state.count
+ )
else:
prev_padded_preconditioners = None
new_stacked_padded_statistics = jnp.stack(new_padded_statistics)
new_stacked_padded_statistics = pjit.with_sharding_constraint(
- new_stacked_padded_statistics, statistics_partition_spec)
+ new_stacked_padded_statistics, statistics_partition_spec
+ )
stacked_padding_starts = jnp.array(padding_starts, jnp.int32)
prev_stacked_padded_preconditioners = _maybe(jnp.stack)(
- prev_padded_preconditioners)
+ prev_padded_preconditioners
+ )
prev_stacked_padded_preconditioners = _maybe(pjit.with_sharding_constraint)(
- prev_padded_preconditioners, statistics_partition_spec)
+ prev_padded_preconditioners, statistics_partition_spec
+ )
def _internal_inverse_pth_root_all():
preconditioners, metrics = _matrix_inverse_pth_root_pjit(
- new_stacked_padded_statistics,
- global_stats.exponents,
- stacked_padding_starts,
- prev_stacked_padded_preconditioners,
- statistics_partition_spec,
+ new_stacked_padded_statistics,
+ global_stats.exponents,
+ stacked_padding_starts,
+ prev_stacked_padded_preconditioners,
+ statistics_partition_spec,
)
return preconditioners, metrics
@@ -1903,39 +2013,47 @@ def _internal_inverse_pth_root_all():
preconditioners_init = new_stacked_padded_statistics[:, :, :pd]
n = new_stacked_padded_statistics.shape[0]
metrics_init = cast(
- TrainingMetrics,
- init_training_metrics(
- n,
- generate_training_metrics=True,
- ))
+ TrainingMetrics,
+ init_training_metrics(
+ n,
+ generate_training_metrics=True,
+ ),
+ )
new_errors = jnp.ones_like(metrics_init.inverse_pth_root_errors) * (
- inverse_failure_threshold)
+ inverse_failure_threshold
+ )
metrics_init = metrics_init.replace(inverse_pth_root_errors=new_errors)
init_state = [preconditioners_init, metrics_init]
new_preconditioners, metrics = efficient_cond(
- perform_step, _internal_inverse_pth_root_all, init_state)
+ perform_step, _internal_inverse_pth_root_all, init_state
+ )
if generate_training_metrics:
new_local_stats_flat = _add_metrics_into_local_stats(
- new_local_stats_flat, metrics, ~perform_step)
- new_local_stats = jax.tree_util.tree_unflatten(treedef,
- new_local_stats_flat)
+ new_local_stats_flat, metrics, ~perform_step
+ )
+ new_local_stats = jax.tree_util.tree_unflatten(
+ treedef, new_local_stats_flat
+ )
errors = metrics.inverse_pth_root_errors
errors = errors.reshape((-1, 1, 1))
predicate = jnp.logical_or(
- jnp.isnan(errors),
- errors >= inverse_failure_threshold).astype(new_preconditioners.dtype)
+ jnp.isnan(errors), errors >= inverse_failure_threshold
+ ).astype(new_preconditioners.dtype)
# TODO(rohananil): Check for numerical instabilities.
new_conditional_preconditioners = (
- predicate * global_stats.preconditioners +
- (1.0 - predicate) * new_preconditioners)
+ predicate * global_stats.preconditioners
+ + (1.0 - predicate) * new_preconditioners
+ )
new_global_stats = GlobalShardedParameterStats(
- new_stacked_padded_statistics,
- new_conditional_preconditioners,
- global_stats.exponents)
+ new_stacked_padded_statistics,
+ new_conditional_preconditioners,
+ global_stats.exponents,
+ )
new_shampoo_state = ShampooState(
- count=state.count + 1,
- stats=ShardedShampooStats(new_global_stats, new_local_stats))
+ count=state.count + 1,
+ stats=ShardedShampooStats(new_global_stats, new_local_stats),
+ )
return updates, new_shampoo_state
def init_fn(params):
@@ -1948,13 +2066,13 @@ def _init(param):
if not _skip_preconditioning(param):
shapes = preconditioner.shapes_for_preconditioners()
statistics = [
- matrix_epsilon * jnp.eye(s[0], dtype=jnp.float32) for s in shapes
+ matrix_epsilon * jnp.eye(s[0], dtype=jnp.float32) for s in shapes
]
# If the preconditioner is using a low-rank representation, initialize
# it to zero instead of an invalid eye.
preconditioners = [
- jnp.eye(s[0], s[1], dtype=jnp.float32) * (s[0] == s[1])
- for s in shapes
+ jnp.eye(s[0], s[1], dtype=jnp.float32) * (s[0] == s[1])
+ for s in shapes
]
diagonal_statistics = []
@@ -1967,25 +2085,28 @@ def _init(param):
momentum = jnp.zeros_like(param)
return ParameterStats(
- diagonal_statistics,
- statistics,
- preconditioners,
- # _quantize_diagonal_statistics(diagonal_statistics),
- # _maybe_quantize_statistics(statistics),
- # _maybe_quantize_preconditioners(preconditioners),
- diagonal_momentum,
- momentum,
- init_training_metrics(
- len(statistics),
- generate_training_metrics,
- ))
+ diagonal_statistics,
+ statistics,
+ preconditioners,
+ # _quantize_diagonal_statistics(diagonal_statistics),
+ # _maybe_quantize_statistics(statistics),
+ # _maybe_quantize_preconditioners(preconditioners),
+ diagonal_momentum,
+ momentum,
+ init_training_metrics(
+ len(statistics),
+ generate_training_metrics,
+ ),
+ )
return ShampooState(
- count=jnp.zeros([], jnp.int32), stats=jax.tree.map(_init, params))
+ count=jnp.zeros([], jnp.int32), stats=jax.tree.map(_init, params)
+ )
def _skip_preconditioning(param):
return len(param.shape) < skip_preconditioning_rank_lt or any(
- s > skip_preconditioning_dim_size_gt for s in param.shape)
+ s > skip_preconditioning_dim_size_gt for s in param.shape
+ )
def _compute_stats(grad, state, param, step):
"""Compute per-parameter statistics."""
@@ -1997,105 +2118,117 @@ def _compute_stats(grad, state, param, step):
def compute_updated_statistics():
return preconditioner.updated_statistics_from_grad(
- state.statistics,
- grad,
- w1=w1,
- w2=w2,
- to_float=_to_float,
- from_float=lambda x: x,
- # from_float=lambda x: _maybe_quantize_statistics([x])[0],
- precision=tensordot_precision,
+ state.statistics,
+ grad,
+ w1=w1,
+ w2=w2,
+ to_float=_to_float,
+ from_float=lambda x: x,
+ # from_float=lambda x: _maybe_quantize_statistics([x])[0],
+ precision=tensordot_precision,
)
if statistics_compute_steps > 1:
perform_step = step % statistics_compute_steps == 0
init_state = state.statistics
new_statistics = list(
- efficient_cond(perform_step, compute_updated_statistics,
- init_state))
+ efficient_cond(perform_step, compute_updated_statistics, init_state)
+ )
else:
new_statistics = compute_updated_statistics()
- return ParameterStats(state.diagonal_statistics,
- new_statistics,
- state.preconditioners,
- state.diagonal_momentum,
- state.momentum,
- state.training_metrics)
+ return ParameterStats(
+ state.diagonal_statistics,
+ new_statistics,
+ state.preconditioners,
+ state.diagonal_momentum,
+ state.momentum,
+ state.training_metrics,
+ )
mi_pth_root = functools.partial(
- matrix_inverse_pth_root,
- ridge_epsilon=matrix_epsilon,
- precision=precision,
- relative_matrix_epsilon=relative_matrix_epsilon,
- lobpcg_topk_precondition=lobpcg_topk_precondition,
- lobpcg_max_iter=lobpcg_max_iter,
- eigh=eigh)
+ matrix_inverse_pth_root,
+ ridge_epsilon=matrix_epsilon,
+ precision=precision,
+ relative_matrix_epsilon=relative_matrix_epsilon,
+ lobpcg_topk_precondition=lobpcg_topk_precondition,
+ lobpcg_max_iter=lobpcg_max_iter,
+ eigh=eigh,
+ )
def _matrix_inverse_pth_root_vmap(xs, ps, padding_starts, prev):
return jax.vmap(mi_pth_root)(
- xs, ps, padding_start=padding_starts, prev=prev)
+ xs, ps, padding_start=padding_starts, prev=prev
+ )
- def _matrix_inverse_pth_root_pjit(xs,
- ps,
- padding_starts,
- prev_preconds=None,
- statistics_partition_spec=None):
+ def _matrix_inverse_pth_root_pjit(
+ xs, ps, padding_starts, prev_preconds=None, statistics_partition_spec=None
+ ):
# Partition the concatenated statistics matrix across all cores.
pspec_for_partition = preconditioner_partition_spec
partitioned_xs = pjit.with_sharding_constraint(xs, pspec_for_partition)
if preconditioner_partition_spec:
partitioned_ps_spec = jax.sharding.PartitionSpec(
- preconditioner_partition_spec[0])
+ preconditioner_partition_spec[0]
+ )
else:
partitioned_ps_spec = None
partitioned_ps = pjit.with_sharding_constraint(ps, partitioned_ps_spec)
partitioned_prev_preconds = _maybe(pjit.with_sharding_constraint)(
- prev_preconds, preconditioner_partition_spec)
+ prev_preconds, preconditioner_partition_spec
+ )
partitioned_padding_starts = pjit.with_sharding_constraint(
- padding_starts, partitioned_ps_spec) # paddings are scalars like ps.
+ padding_starts, partitioned_ps_spec
+ ) # paddings are scalars like ps.
# Run matrix inverse pth root on each shard.
partitioned_preconditioners, partitioned_metrics = (
- _matrix_inverse_pth_root_vmap(
- partitioned_xs,
- partitioned_ps,
- partitioned_padding_starts,
- prev=partitioned_prev_preconds))
+ _matrix_inverse_pth_root_vmap(
+ partitioned_xs,
+ partitioned_ps,
+ partitioned_padding_starts,
+ prev=partitioned_prev_preconds,
+ )
+ )
# Reshard output to have the same PSpec as input. This is required to avoid
# vmap seeing the full set of statistics.
partitioned_preconditioners = pjit.with_sharding_constraint(
- partitioned_preconditioners, pspec_for_partition)
+ partitioned_preconditioners, pspec_for_partition
+ )
# Recombine the outputs at each core.
- preconditioners = pjit.with_sharding_constraint(partitioned_preconditioners,
- statistics_partition_spec)
- metrics = pjit.with_sharding_constraint(partitioned_metrics,
- jax.sharding.PartitionSpec())
+ preconditioners = pjit.with_sharding_constraint(
+ partitioned_preconditioners, statistics_partition_spec
+ )
+ metrics = pjit.with_sharding_constraint(
+ partitioned_metrics, jax.sharding.PartitionSpec()
+ )
return preconditioners, metrics
- def _pmap_compute_preconditioners(states,
- step,
- statistics,
- num_statistics_per_state,
- original_shapes,
- exponents,
- max_size,
- prev_preconditioners):
+ def _pmap_compute_preconditioners(
+ states,
+ step,
+ statistics,
+ num_statistics_per_state,
+ original_shapes,
+ exponents,
+ max_size,
+ prev_preconditioners,
+ ):
"""Computes preconditioners for given statistics in states in PMAP mode.
- Args:
- states: A list of optimizer states.
- step: Current step number
- statistics: A list of statistics for all variables (for every dim)
- num_statistics_per_state: Number of statistis per state to reconstruct
- output states.
- original_shapes: A list of shapes of the statistics.
- exponents: Exponent power to use for inverse-pth roots.
- max_size: Maximum dim of the statistics to pad.
- prev_preconditioners: Previously available preconditioner.
-
- Returns:
- New optimizer states after computing the preconditioner.
- """
+ Args:
+ states: A list of optimizer states.
+ step: Current step number
+ statistics: A list of statistics for all variables (for every dim)
+      num_statistics_per_state: Number of statistics per state to reconstruct
+ output states.
+ original_shapes: A list of shapes of the statistics.
+ exponents: Exponent power to use for inverse-pth roots.
+ max_size: Maximum dim of the statistics to pad.
+ prev_preconditioners: Previously available preconditioner.
+
+ Returns:
+ New optimizer states after computing the preconditioner.
+ """
if batch_axis_name:
num_devices = lax.psum(1, batch_axis_name)
else:
@@ -2103,13 +2236,15 @@ def _pmap_compute_preconditioners(states,
num_statistics = len(statistics)
# Pad statistics and exponents to next multiple of num_devices.
packed_statistics = [
- pad_square_matrix(stat, max_size) for stat in statistics
+ pad_square_matrix(stat, max_size) for stat in statistics
]
to_pad = -num_statistics % num_devices
- packed_statistics.extend([
+ packed_statistics.extend(
+ [
jnp.eye(max_size, dtype=packed_statistics[0].dtype)
for _ in range(to_pad)
- ])
+ ]
+ )
exponents.extend([1 for _ in range(to_pad)])
paddings = [len(stat) for stat in statistics] + [0] * to_pad
@@ -2119,7 +2254,8 @@ def _pmap_compute_preconditioners(states,
if reuse_preconditioner:
assert len(prev_preconditioners) == num_statistics
packed_preconditioners = pad_and_maybe_zero_preconditioners(
- prev_preconditioners, len(packed_statistics), max_size, step)
+ prev_preconditioners, len(packed_statistics), max_size, step
+ )
else:
packed_preconditioners = None
@@ -2132,10 +2268,10 @@ def _internal_inverse_pth_root_all():
if batch_axis_name:
current_replica = lax.axis_index(batch_axis_name)
preconditioners, metrics = _matrix_inverse_pth_root_vmap(
- all_statistics[current_replica],
- all_exponents[current_replica],
- all_paddings[current_replica],
- _maybe_ix(all_preconditioners, current_replica),
+ all_statistics[current_replica],
+ all_exponents[current_replica],
+ all_paddings[current_replica],
+ _maybe_ix(all_preconditioners, current_replica),
)
preconditioners = jax.lax.all_gather(preconditioners, batch_axis_name)
metrics = jax.lax.all_gather(metrics, batch_axis_name)
@@ -2143,14 +2279,15 @@ def _internal_inverse_pth_root_all():
metrics_flat = jax.tree.map(unbatch, metrics)
else:
preconditioners, metrics = _matrix_inverse_pth_root_vmap(
- all_statistics[0],
- all_exponents[0],
- all_paddings[0],
- _maybe_ix(all_preconditioners, 0),
+ all_statistics[0],
+ all_exponents[0],
+ all_paddings[0],
+ _maybe_ix(all_preconditioners, 0),
)
preconditioners_flat = unbatch(jnp.stack([preconditioners]))
metrics = jax.tree.map(
- functools.partial(jnp.expand_dims, axis=0), metrics)
+ functools.partial(jnp.expand_dims, axis=0), metrics
+ )
metrics_flat = jax.tree.map(unbatch, metrics)
return preconditioners_flat, metrics_flat
@@ -2163,40 +2300,57 @@ def _internal_inverse_pth_root_all():
# shaped tensors. Note statistics will be ignored as we are passing in
# a large error value.
preconditioners_init = [
- s[:, :precond_dim(s.shape[0])] for s in packed_statistics
+ s[:, : precond_dim(s.shape[0])] for s in packed_statistics
]
n = len(packed_statistics)
metrics_init = jax.tree.map(
- lambda x: [x] * n,
- default_training_metrics().replace(
- inverse_pth_root_errors=inverse_failure_threshold))
+ lambda x: [x] * n,
+ default_training_metrics().replace(
+ inverse_pth_root_errors=inverse_failure_threshold
+ ),
+ )
init_state = [preconditioners_init, metrics_init]
preconditioners_flat, metrics_flat = efficient_cond(
- perform_step, _internal_inverse_pth_root_all, init_state)
+ perform_step, _internal_inverse_pth_root_all, init_state
+ )
def _skip(error):
condition = jnp.logical_or(
- jnp.isnan(error), error >= inverse_failure_threshold)
+ jnp.isnan(error), error >= inverse_failure_threshold
+ )
return condition.astype(error.dtype)
def _select_preconditioner(error, new_p, old_p):
return lax.cond(
- _skip(error), lambda _: old_p, lambda _: new_p, operand=None)
+ _skip(error), lambda _: old_p, lambda _: new_p, operand=None
+ )
new_preconditioners_flat = []
new_errors_flat = metrics_flat.inverse_pth_root_errors
- for p, shape, prev_p, error in zip(preconditioners_flat, original_shapes,
- prev_preconditioners, new_errors_flat):
+ for p, shape, prev_p, error in zip(
+ preconditioners_flat,
+ original_shapes,
+ prev_preconditioners,
+ new_errors_flat,
+ ):
new_preconditioners_flat.append(
- _select_preconditioner(error, p[:shape[0], :shape[1]], prev_p))
+ _select_preconditioner(error, p[: shape[0], : shape[1]], prev_p)
+ )
- assert len(states) == (len(num_statistics_per_state),
- f"{len(states)} vs {len(num_statistics_per_state)}")
+    assert len(states) == len(num_statistics_per_state), (
+      f'{len(states)} vs {len(num_statistics_per_state)}'
+    )
assert len(new_preconditioners_flat) == num_statistics
assert len(new_errors_flat) == len(packed_statistics), (
- len(new_errors_flat), len(packed_statistics))
+ len(new_errors_flat),
+ len(packed_statistics),
+ )
assert len(new_errors_flat) == num_statistics + to_pad, (
- len(new_errors_flat), num_statistics, to_pad)
+ len(new_errors_flat),
+ num_statistics,
+ to_pad,
+ )
    # Add back empty preconditioners so that we can set the optimizer state.
preconditioners_for_states = []
@@ -2206,26 +2360,31 @@ def _select_preconditioner(error, new_p, old_p):
if num_statistics == 0:
preconditioners_for_states.append([])
metrics_for_states.append(
- init_training_metrics(0, generate_training_metrics))
+ init_training_metrics(0, generate_training_metrics)
+ )
else:
- preconditioners_for_state = new_preconditioners_flat[idx:idx +
- num_statistics]
+ preconditioners_for_state = new_preconditioners_flat[
+ idx : idx + num_statistics
+ ]
assert len(state.statistics) == len(preconditioners_for_state)
preconditioners_for_states.append(preconditioners_for_state)
if generate_training_metrics:
# pylint:disable=cell-var-from-loop Used immediately.
metrics_for_state = jax.tree.map(
- lambda x: jnp.stack(x[idx:idx + num_statistics]),
- metrics_flat,
- is_leaf=lambda x: isinstance(x, list))
+ lambda x: jnp.stack(x[idx : idx + num_statistics]),
+ metrics_flat,
+ is_leaf=lambda x: isinstance(x, list),
+ )
assert jax.tree_util.tree_all(
- jax.tree.map(lambda x: len(state.statistics) == len(x),
- metrics_for_state))
+ jax.tree.map(
+ lambda x: len(state.statistics) == len(x), metrics_for_state
+ )
+ )
# If we skipped preconditioner computation, record old metrics.
- metrics_for_state = efficient_cond(perform_step,
- lambda: [metrics_for_state],
- [state.training_metrics])[0]
+ metrics_for_state = efficient_cond(
+ perform_step, lambda: [metrics_for_state], [state.training_metrics]
+ )[0]
# pylint:enable=cell-var-from-loop
else:
metrics_for_state = optax.MaskedNode()
@@ -2234,32 +2393,36 @@ def _select_preconditioner(error, new_p, old_p):
idx += num_statistics
new_states = []
for state, new_preconditioners, new_metrics in zip(
- states, preconditioners_for_states, metrics_for_states):
+ states, preconditioners_for_states, metrics_for_states
+ ):
# Note the preconditioner may have been skipped, but we still update the
# metrics with the new error values; whether the preconditioner that's
# actively being used is stale can be derived from the new_metrics
# being greater than the failure threshold.
new_states.append(
- ParameterStats(state.diagonal_statistics,
- state.statistics,
- new_preconditioners,
- state.diagonal_momentum,
- state.momentum,
- new_metrics))
+ ParameterStats(
+ state.diagonal_statistics,
+ state.statistics,
+ new_preconditioners,
+ state.diagonal_momentum,
+ state.momentum,
+ new_metrics,
+ )
+ )
return new_states
def _compute_preconditioners(states, params, step):
"""Computes preconditioners for given statistics in states.
- Args:
- states: A list of optimizer states.
- params: A list of params.
- step: Current step number
+ Args:
+ states: A list of optimizer states.
+ params: A list of params.
+ step: Current step number
- Returns:
- New optimizer states after computing the preconditioner.
- """
+ Returns:
+ New optimizer states after computing the preconditioner.
+ """
statistics = []
num_statistics_per_state = []
original_shapes = []
@@ -2274,8 +2437,11 @@ def _compute_preconditioners(states, params, step):
if num_statistics > 0:
preconditioner = preconditioner_from_params(param)
for statistic in state.statistics:
- exponents.append(preconditioner.exponent_for_preconditioner(
- ) if exponent_override == 0 else exponent_override)
+ exponents.append(
+ preconditioner.exponent_for_preconditioner()
+ if exponent_override == 0
+ else exponent_override
+ )
original_shapes_for_state.append(statistic.shape)
max_size = max(max_size, statistic.shape[0])
@@ -2283,14 +2449,16 @@ def _compute_preconditioners(states, params, step):
prev_preconditioners.extend(state.preconditioners)
original_shapes.extend(original_shapes_for_state)
- return _pmap_compute_preconditioners(states,
- step,
- statistics,
- num_statistics_per_state,
- original_shapes,
- exponents,
- max_size,
- prev_preconditioners)
+ return _pmap_compute_preconditioners(
+ states,
+ step,
+ statistics,
+ num_statistics_per_state,
+ original_shapes,
+ exponents,
+ max_size,
+ prev_preconditioners,
+ )
def _transform_grad(grad, state, param, step):
"""Transform per-parameter gradients."""
@@ -2298,21 +2466,25 @@ def _transform_grad(grad, state, param, step):
sgd_update = grad
new_diagonal_statistics = state.diagonal_statistics
- if (graft_type == GraftingType.ADAGRAD or
- graft_type == GraftingType.ADAGRAD_NORMALIZED):
-
+ if (
+ graft_type == GraftingType.ADAGRAD
+ or graft_type == GraftingType.ADAGRAD_NORMALIZED
+ ):
scaled_grad = grad
if graft_type == GraftingType.ADAGRAD_NORMALIZED:
scaled_grad = grad / (jnp.linalg.norm(grad) + _EPSILON)
new_diagonal_statistics = (
- state.diagonal_statistics.to_float() + jnp.square(scaled_grad))
+ state.diagonal_statistics.to_float() + jnp.square(scaled_grad)
+ )
adagrad_update = scaled_grad / (
- jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon
+ )
grafting_update = adagrad_update
- elif (graft_type == GraftingType.RMSPROP or
- graft_type == GraftingType.RMSPROP_NORMALIZED):
-
+ elif (
+ graft_type == GraftingType.RMSPROP
+ or graft_type == GraftingType.RMSPROP_NORMALIZED
+ ):
scaled_grad = grad
if graft_type == GraftingType.RMSPROP_NORMALIZED:
scaled_grad = grad / (jnp.linalg.norm(grad) + _EPSILON)
@@ -2321,15 +2493,19 @@ def _transform_grad(grad, state, param, step):
w2 = jnp.where(beta2 == 1.0, beta2, 1.0 - beta2)
new_diagonal_statistics = (
- w1 * state.diagonal_statistics + w2 * jnp.square(scaled_grad))
+ w1 * state.diagonal_statistics + w2 * jnp.square(scaled_grad)
+ )
rmsprop_update = scaled_grad / (
- jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon)
+ jnp.sqrt(new_diagonal_statistics) + diagonal_epsilon
+ )
if clip_by_scaled_gradient_norm:
scaled_grad_norm = jnp.linalg.norm(rmsprop_update) / (
- jnp.sqrt(float(rmsprop_update.size)))
+ jnp.sqrt(float(rmsprop_update.size))
+ )
clipping_denom = jnp.maximum(
- 1., scaled_grad_norm / clip_by_scaled_gradient_norm)
+ 1.0, scaled_grad_norm / clip_by_scaled_gradient_norm
+ )
rmsprop_update /= clipping_denom
grafting_update = rmsprop_update
@@ -2349,12 +2525,14 @@ def _transform_grad(grad, state, param, step):
precond_grad = grad
if not _skip_preconditioning(param):
- precond_grad = preconditioner.preconditioned_grad(precond_grad,
- state.preconditioners)
+ precond_grad = preconditioner.preconditioned_grad(
+ precond_grad, state.preconditioners
+ )
else:
if graft_type == GraftingType.NONE:
- logging.error("skipping preconditioning without grafting for param %s",
- param)
+ logging.error(
+ 'skipping preconditioning without grafting for param %s', param
+ )
precond_grad = grafting_update
grafting_update_norm = jnp.linalg.norm(grafting_update)
@@ -2369,39 +2547,49 @@ def _transform_grad(grad, state, param, step):
shampoo_update_with_wd = shampoo_update
grafting_update_with_wd = grafting_update
- if (weight_decay != 0 and weight_decay is not None and
- not decoupled_weight_decay):
+ if (
+ weight_decay != 0
+ and weight_decay is not None
+ and not decoupled_weight_decay
+ ):
shampoo_update_with_wd = shampoo_update + weight_decay * param
grafting_update_with_wd = grafting_update + weight_decay * param
w = (1.0 - beta1) if moving_average_for_momentum else 1.0
shampoo_update_with_wd_momentum = (
- state.momentum * beta1 + w * shampoo_update_with_wd)
+ state.momentum * beta1 + w * shampoo_update_with_wd
+ )
grafting_update_with_wd_momentum = (
- state.diagonal_momentum * beta1 + w * grafting_update_with_wd)
+ state.diagonal_momentum * beta1 + w * grafting_update_with_wd
+ )
run_shampoo = (step >= start_preconditioning_step).astype(
- grafting_update_with_wd_momentum.dtype)
+ grafting_update_with_wd_momentum.dtype
+ )
momentum_update = (
- run_shampoo * shampoo_update_with_wd_momentum +
- (1.0 - run_shampoo) * grafting_update_with_wd_momentum)
+ run_shampoo * shampoo_update_with_wd_momentum
+ + (1.0 - run_shampoo) * grafting_update_with_wd_momentum
+ )
wd_update = (
- run_shampoo * shampoo_update_with_wd +
- (1.0 - run_shampoo) * grafting_update_with_wd)
+ run_shampoo * shampoo_update_with_wd
+ + (1.0 - run_shampoo) * grafting_update_with_wd
+ )
nesterov_momentum_update = momentum_update
if nesterov:
nesterov_momentum_update = w * wd_update + beta1 * momentum_update
- if (weight_decay != 0 and weight_decay is not None and
- decoupled_weight_decay):
+ if (
+ weight_decay != 0 and weight_decay is not None and decoupled_weight_decay
+ ):
nesterov_momentum_update = (
- nesterov_momentum_update + lr * weight_decay * param)
+ nesterov_momentum_update + lr * weight_decay * param
+ )
momentum_multiplier = lr if decoupled_learning_rate else 1.0
transformed_update = -1.0 * momentum_multiplier * nesterov_momentum_update
@@ -2409,26 +2597,28 @@ def _transform_grad(grad, state, param, step):
new_diagonal_momentum = grafting_update_with_wd_momentum
new_momentum = shampoo_update_with_wd_momentum
- param_stats = ParameterStats(new_diagonal_statistics,
- state.statistics,
- state.preconditioners,
- new_diagonal_momentum,
- new_momentum,
- state.training_metrics)
+ param_stats = ParameterStats(
+ new_diagonal_statistics,
+ state.statistics,
+ state.preconditioners,
+ new_diagonal_momentum,
+ new_momentum,
+ state.training_metrics,
+ )
return transformed_update, param_stats
def update_fn(grads, state, params):
"""Transform the input gradient and update all statistics.
- Args:
- grads: the gradient tensors for the parameters and any custom
- gradients for preconditioners.
- state: a named tuple containing the state of the optimizer
- params: the parameters that should be updated.
+ Args:
+ grads: the gradient tensors for the parameters and any custom
+ gradients for preconditioners.
+ state: a named tuple containing the state of the optimizer
+ params: the parameters that should be updated.
- Returns:
- A tuple containing the new parameters and the new optimizer state.
- """
+ Returns:
+ A tuple containing the new parameters and the new optimizer state.
+ """
grads_custom = None
if custom_preconditioner and isinstance(grads, tuple):
grads, grads_custom = grads
@@ -2442,23 +2632,21 @@ def update_fn(grads, state, params):
stats_grads = treedef.flatten_up_to(grads_custom)
new_stats_flat = jax.tree.map(
- lambda g,
- s,
- p: _compute_stats(g, s, p, state.count),
- stats_grads,
- stats_flat,
- params_flat)
-
- new_stats_flat = _compute_preconditioners(new_stats_flat,
- params_flat,
- state.count)
+ lambda g, s, p: _compute_stats(g, s, p, state.count),
+ stats_grads,
+ stats_flat,
+ params_flat,
+ )
+
+ new_stats_flat = _compute_preconditioners(
+ new_stats_flat, params_flat, state.count
+ )
outputs = jax.tree.map(
- lambda g,
- s,
- p: _transform_grad(g, s, p, state.count),
- grads_flat,
- new_stats_flat,
- params_flat)
+ lambda g, s, p: _transform_grad(g, s, p, state.count),
+ grads_flat,
+ new_stats_flat,
+ params_flat,
+ )
updates_flat, new_stats_flat = list(zip(*outputs)) if outputs else ((), ())
updates = jax.tree_util.tree_unflatten(treedef, updates_flat)
new_stats = jax.tree_util.tree_unflatten(treedef, new_stats_flat)
@@ -2472,9 +2660,10 @@ def update_fn(grads, state, params):
def _init_fns(unused_params):
return InitFnState(
- init_fn=opt_init_fn,
- pspec_fn=sharded_init_partition_spec_fn,
- shape_and_dtype_fn=sharded_init_shape_and_dtype_fn)
+ init_fn=opt_init_fn,
+ pspec_fn=sharded_init_partition_spec_fn,
+ shape_and_dtype_fn=sharded_init_shape_and_dtype_fn,
+ )
opt_update_fn = sharded_update_fn
return optax.GradientTransformation(_init_fns, opt_update_fn)
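
The reformatted code above implements distributed Shampoo: per (blocked) parameter it accumulates per-axis statistics, takes their inverse-pth roots as preconditioners, and grafts the result onto a first-order update. The sketch below illustrates that core idea for a single rank-2 parameter; the EMA weights, the ridge epsilon, and the exponent 2 * rank follow the conventions visible in the diff, but the helper names are made up and this is an illustration, not the library's implementation (which adds blocking, grafting, padding, and pmap/pjit sharding).

import jax.numpy as jnp


def updated_statistics(stats_l, stats_r, grad, beta2=0.999):
  # EMA of the left/right Gram matrices G Gᵀ and Gᵀ G, mirroring the role of
  # Preconditioner.updated_statistics_from_grad in the diff above.
  w1 = beta2
  w2 = jnp.where(beta2 == 1.0, beta2, 1.0 - beta2)
  return (w1 * stats_l + w2 * grad @ grad.T,
          w1 * stats_r + w2 * grad.T @ grad)


def inverse_pth_root(stat, p, ridge_epsilon=1e-6):
  # Eigendecomposition-based inverse-pth root (the `eigh=True` path); the real
  # code also offers a coupled-Newton iteration and LOBPCG-based variants.
  s, u = jnp.linalg.eigh(stat + ridge_epsilon * jnp.eye(stat.shape[0]))
  s = jnp.maximum(s, ridge_epsilon)
  return (u * s ** (-1.0 / p)) @ u.T


def shampoo_direction(grad, stats_l, stats_r):
  # For a rank-2 parameter, exponent_for_preconditioner() gives 2 * rank = 4.
  p_l = inverse_pth_root(stats_l, 4)
  p_r = inverse_pth_root(stats_r, 4)
  return p_l @ grad @ p_r


grad = jnp.ones((8, 4)) * 0.1
stats_l, stats_r = jnp.zeros((8, 8)), jnp.zeros((4, 4))
stats_l, stats_r = updated_statistics(stats_l, stats_r, grad)
update = shampoo_direction(grad, stats_l, stats_r)  # same shape as grad
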
diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py
index 2cd054062..526afe7d5 100644
--- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py
+++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py
@@ -10,17 +10,20 @@
import optax
from algoperf import spec
-from reference_algorithms.paper_baselines.shampoo.jax.distributed_shampoo import \
- distributed_shampoo
+from reference_algorithms.paper_baselines.shampoo.jax.distributed_shampoo import (
+ distributed_shampoo,
+)
_GRAD_CLIP_EPS = 1e-6
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Shampoo optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -30,102 +33,116 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+ init_value=hyperparameters.learning_rate, decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters)
opt_init_fn, opt_update_fn = distributed_shampoo(
- learning_rate=lr_schedule_fn,
- beta1=1.0 - hyperparameters.one_minus_beta1,
- beta2=hyperparameters.beta2,
- weight_decay=hyperparameters.weight_decay,
- batch_axis_name='batch',
- eigh=False)
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ learning_rate=lr_schedule_fn,
+ beta1=1.0 - hyperparameters.one_minus_beta1,
+ beta2=hyperparameters.beta2,
+ weight_decay=hyperparameters.weight_decay,
+ batch_axis_name='batch',
+ eigh=False,
+ )
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -142,37 +159,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -208,14 +231,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/target_setting_algorithms/cosine_warmup.py b/reference_algorithms/target_setting_algorithms/cosine_warmup.py
index 116ebc555..eeb87cd87 100644
--- a/reference_algorithms/target_setting_algorithms/cosine_warmup.py
+++ b/reference_algorithms/target_setting_algorithms/cosine_warmup.py
@@ -9,27 +9,31 @@
def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=hyperparameters.warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=hyperparameters.warmup_steps,
+ )
cosine_steps = max(step_hint - hyperparameters.warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+ init_value=hyperparameters.learning_rate, decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn],
- boundaries=[hyperparameters.warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[hyperparameters.warmup_steps]
+ )
return schedule_fn
def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
warmup = LinearLR(
- optimizer,
- start_factor=1e-10,
- end_factor=1.,
- total_iters=hyperparameters.warmup_steps)
+ optimizer,
+ start_factor=1e-10,
+ end_factor=1.0,
+ total_iters=hyperparameters.warmup_steps,
+ )
cosine_steps = max(step_hint - hyperparameters.warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
- optimizer,
- schedulers=[warmup, cosine_decay],
- milestones=[hyperparameters.warmup_steps])
+ optimizer,
+ schedulers=[warmup, cosine_decay],
+ milestones=[hyperparameters.warmup_steps],
+ )
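
To see the shape of the schedule that jax_cosine_warmup builds, the standalone sketch below evaluates the same optax composition at a few steps, using made-up values (warmup_steps=100, learning_rate=1e-3, step_hint=1000):

import optax

warmup_steps, peak_lr, step_hint = 100, 1e-3, 1000  # made-up values
warmup_fn = optax.linear_schedule(
  init_value=0.0, end_value=peak_lr, transition_steps=warmup_steps
)
cosine_fn = optax.cosine_decay_schedule(
  init_value=peak_lr, decay_steps=max(step_hint - warmup_steps, 1)
)
schedule_fn = optax.join_schedules(
  schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
)

print(schedule_fn(0))     # 0.0    (start of linear warmup)
print(schedule_fn(50))    # ~5e-4  (halfway through warmup)
print(schedule_fn(100))   # ~1e-3  (peak, hand-off to cosine decay)
print(schedule_fn(1000))  # ~0.0   (end of cosine decay)
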
diff --git a/reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json
index cab6fd5f7..6061940a9 100644
--- a/reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/criteo1tb/tuning_search_space.json
@@ -1,27 +1,17 @@
{
"learning_rate": {
- "feasible_points": [
- 0.0033313215673016375
- ]
+ "feasible_points": [0.0033313215673016375]
},
"beta1": {
- "feasible_points": [
- 0.948000082541717
- ]
+ "feasible_points": [0.948000082541717]
},
"beta2": {
- "feasible_points": [
- 0.9987934318891598
- ]
+ "feasible_points": [0.9987934318891598]
},
"warmup_steps": {
- "feasible_points": [
- 159
- ]
+ "feasible_points": [159]
},
"weight_decay": {
- "feasible_points": [
- 0.0035784380304876183
- ]
+ "feasible_points": [0.0035784380304876183]
}
}
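
Each entry in this search space pins a single feasible point per hyperparameter. Purely as an illustration of how such a file maps onto the `hyperparameters.<name>` attribute access used by the submissions, here is a sketch that loads it into a stand-in container (`SimpleNamespace` is an assumption, not the tuning harness's actual type):

import json
from types import SimpleNamespace

with open('tuning_search_space.json') as f:  # path shortened for illustration
  space = json.load(f)

hyperparameters = SimpleNamespace(
  **{name: spec['feasible_points'][0] for name, spec in space.items()}
)
print(hyperparameters.learning_rate, hyperparameters.warmup_steps)
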
diff --git a/reference_algorithms/target_setting_algorithms/data_selection.py b/reference_algorithms/target_setting_algorithms/data_selection.py
index 5e70f9f8b..e0d9c0ee9 100644
--- a/reference_algorithms/target_setting_algorithms/data_selection.py
+++ b/reference_algorithms/target_setting_algorithms/data_selection.py
@@ -4,14 +4,15 @@
def data_selection(
- workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]:
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Tuple[spec.Tensor, spec.Tensor, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/reference_algorithms/target_setting_algorithms/jax_adamw.py b/reference_algorithms/target_setting_algorithms/jax_adamw.py
index b64f0dfd6..e6f8d915c 100644
--- a/reference_algorithms/target_setting_algorithms/jax_adamw.py
+++ b/reference_algorithms/target_setting_algorithms/jax_adamw.py
@@ -1,4 +1,5 @@
"""Submission file for an AdamW optimizer with warmup+cosine LR in Jax."""
+
from flax import jax_utils
import jax
import jax.numpy as jnp
@@ -6,39 +7,48 @@
from algoperf import spec
from reference_algorithms.target_setting_algorithms import cosine_warmup
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import \
- get_batch_size # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_submission_base import \
- update_params # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.data_selection import (
+ data_selection,
+) # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.get_batch_size import (
+ get_batch_size,
+) # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.jax_submission_base import (
+ update_params,
+) # pylint: disable=unused-import
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates an AdamW optimizer and a learning rate schedule."""
del model_params
del model_state
del rng
target_setting_step_hint = int(0.75 * workload.step_hint)
- lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint,
- hyperparameters)
+ lr_schedule_fn = cosine_warmup.jax_cosine_warmup(
+ target_setting_step_hint, hyperparameters
+ )
# Create optimizer.
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
epsilon = (
- hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8)
+ hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8
+ )
opt_init_fn, opt_update_fn = optax.adamw(
- learning_rate=lr_schedule_fn,
- b1=hyperparameters.beta1,
- b2=hyperparameters.beta2,
- eps=epsilon,
- weight_decay=hyperparameters.weight_decay)
+ learning_rate=lr_schedule_fn,
+ b1=hyperparameters.beta1,
+ b2=hyperparameters.beta2,
+ eps=epsilon,
+ weight_decay=hyperparameters.weight_decay,
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
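
The (opt_init_fn, opt_update_fn) pair returned here is a plain optax GradientTransformation; outside the pmapped training loop it would be driven roughly as in the sketch below (toy parameters and gradients, constant learning rate, all values made up):

import jax.numpy as jnp
import optax

tx = optax.adamw(
  learning_rate=1e-3, b1=0.9, b2=0.999, eps=1e-8, weight_decay=1e-2
)
params = {'w': jnp.ones((3,)), 'b': jnp.zeros(())}
opt_state = tx.init(params)

grads = {'w': jnp.full((3,), 0.1), 'b': jnp.array(0.5)}
# `params` must be passed so the decoupled weight decay can be applied.
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
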
diff --git a/reference_algorithms/target_setting_algorithms/jax_momentum.py b/reference_algorithms/target_setting_algorithms/jax_momentum.py
index a6c3d853b..b37464c1c 100644
--- a/reference_algorithms/target_setting_algorithms/jax_momentum.py
+++ b/reference_algorithms/target_setting_algorithms/jax_momentum.py
@@ -8,19 +8,24 @@
import optax
from algoperf import spec
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import \
- get_batch_size # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_submission_base import \
- update_params # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.data_selection import (
+ data_selection,
+) # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.get_batch_size import (
+ get_batch_size,
+) # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.jax_submission_base import (
+ update_params,
+) # pylint: disable=unused-import
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -28,38 +33,44 @@ def init_optimizer_state(workload: spec.Workload,
# Create learning rate schedule.
target_setting_step_hint = int(0.75 * workload.step_hint)
- lr_schedule_fn = create_lr_schedule_fn(target_setting_step_hint,
- hyperparameters)
+ lr_schedule_fn = create_lr_schedule_fn(
+ target_setting_step_hint, hyperparameters
+ )
# Create optimizer.
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
opt_init_fn, opt_update_fn = sgd(
- learning_rate=lr_schedule_fn,
- weight_decay=hyperparameters.weight_decay,
- momentum=hyperparameters.beta1,
- nesterov=False)
+ learning_rate=lr_schedule_fn,
+ weight_decay=hyperparameters.weight_decay,
+ momentum=hyperparameters.beta1,
+ nesterov=False,
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
def create_lr_schedule_fn(
- step_hint: int,
- hyperparameters: spec.Hyperparameters) -> Callable[[int], float]:
+ step_hint: int, hyperparameters: spec.Hyperparameters
+) -> Callable[[int], float]:
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=hyperparameters.warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=hyperparameters.warmup_steps,
+ )
decay_steps = step_hint - hyperparameters.warmup_steps
polynomial_schedule_fn = optax.polynomial_schedule(
- init_value=hyperparameters.learning_rate,
- end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
- power=1,
- transition_steps=int(decay_steps * hyperparameters.decay_steps_factor))
+ init_value=hyperparameters.learning_rate,
+ end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
+ power=1,
+ transition_steps=int(decay_steps * hyperparameters.decay_steps_factor),
+ )
lr_schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, polynomial_schedule_fn],
- boundaries=[hyperparameters.warmup_steps])
+ schedules=[warmup_fn, polynomial_schedule_fn],
+ boundaries=[hyperparameters.warmup_steps],
+ )
return lr_schedule_fn
@@ -85,6 +96,8 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False):
{SGD, Momentum, Nesterov} update.
"""
return optax.chain(
- optax.add_decayed_weights(weight_decay),
- optax.sgd(
- learning_rate=learning_rate, momentum=momentum, nesterov=nesterov))
+ optax.add_decayed_weights(weight_decay),
+ optax.sgd(
+ learning_rate=learning_rate, momentum=momentum, nesterov=nesterov
+ ),
+ )
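
[Editor's note] As a quick illustration of the schedule composition above (linear warmup joined with a power-1 polynomial decay), here is a minimal sketch with made-up values; optax.join_schedules switches at the given boundary and offsets the step count passed to the later schedule.

    import optax

    warmup = optax.linear_schedule(init_value=0.0, end_value=1.0, transition_steps=5)
    decay = optax.polynomial_schedule(
        init_value=1.0, end_value=0.01, power=1, transition_steps=20
    )
    lr_fn = optax.join_schedules(schedules=[warmup, decay], boundaries=[5])

    for step in (0, 5, 15, 25):
        # Ramps 0 -> 1 over 5 steps, then decays linearly toward 0.01 over 20 steps.
        print(step, float(lr_fn(step)))
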
diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py
index 597a43c9e..7d4a1fcc1 100644
--- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py
+++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py
@@ -10,26 +10,28 @@
from algoperf import spec
from reference_algorithms.target_setting_algorithms import cosine_warmup
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import \
- get_batch_size # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_submission_base import \
- update_params # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.data_selection import (
+ data_selection,
+) # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.get_batch_size import (
+ get_batch_size,
+) # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.jax_submission_base import (
+ update_params,
+) # pylint: disable=unused-import
# Forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py
def nadamw(
- learning_rate: Union[float, optax.Schedule],
- b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- weight_decay: float = 0.0,
- weight_decay_mask: Optional[Union[Any, Callable[[optax.Params],
- Any]]] = None,
+ learning_rate: Union[float, optax.Schedule],
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ weight_decay: float = 0.0,
+ weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None,
) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
References:
@@ -61,19 +63,22 @@ def nadamw(
An (init_fn, update_fn) tuple.
"""
return optax.chain(
- scale_by_nadam(b1, b2, eps, eps_root, debias),
- optax.add_decayed_weights(weight_decay, weight_decay_mask),
- scale_by_learning_rate(learning_rate))
+ scale_by_nadam(b1, b2, eps, eps_root, debias),
+ optax.add_decayed_weights(weight_decay, weight_decay_mask),
+ scale_by_learning_rate(learning_rate),
+ )
# All functions below are forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py
-def scale_by_nadam(b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- power: float = 0.5) -> optax.GradientTransformation:
+def scale_by_nadam(
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ power: float = 0.5,
+) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
References:
There seem to be multiple versions of NAdam. The original version is here
@@ -109,7 +114,8 @@ def update_fn(updates, state, params=None):
mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
nu_hat = nu if not debias else _bias_correction(nu, b2, count)
updates = jax.tree.map(
- lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
+ lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat
+ )
return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
return optax.GradientTransformation(init_fn, update_fn)
@@ -117,6 +123,7 @@ def update_fn(updates, state, params=None):
class ScaleByAdamState(NamedTuple):
"""State for the NAdam algorithm."""
+
count: chex.Array # shape=(), dtype=jnp.int32.
mu: optax.Updates
nu: optax.Updates
@@ -125,7 +132,8 @@ class ScaleByAdamState(NamedTuple):
def _update_moment(updates, moments, decay, order):
"""Compute the exponential moving average of the `order-th` moment."""
return jax.tree.map(
- lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
+ lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments
+ )
def _bias_correction(moment, decay, count):
@@ -141,31 +149,37 @@ def scale_by_learning_rate(learning_rate, flip_sign=True):
return optax.scale(m * learning_rate)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_params
del model_state
del rng
target_setting_step_hint = int(0.75 * workload.step_hint)
- lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint,
- hyperparameters)
+ lr_schedule_fn = cosine_warmup.jax_cosine_warmup(
+ target_setting_step_hint, hyperparameters
+ )
# Create optimizer.
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
epsilon = (
- hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8)
+ hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8
+ )
opt_init_fn, opt_update_fn = nadamw(
- learning_rate=lr_schedule_fn,
- b1=hyperparameters.beta1,
- b2=hyperparameters.beta2,
- eps=epsilon,
- weight_decay=hyperparameters.weight_decay)
+ learning_rate=lr_schedule_fn,
+ b1=hyperparameters.beta1,
+ b2=hyperparameters.beta2,
+ eps=epsilon,
+ weight_decay=hyperparameters.weight_decay,
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
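
[Editor's note] The debias branch above relies on the standard Adam-style bias correction: an EMA initialised at zero underestimates the moment early in training. The helper's body is not shown in this hunk, so the sketch below shows the conventional formula rather than a verbatim copy.

    import jax.numpy as jnp

    def bias_correction(moment, decay, count):
      # Rescale a zero-initialised EMA by 1 / (1 - decay**count).
      return moment / (1.0 - decay**count)

    m = jnp.array(0.1)                 # EMA after one step with decay b1 = 0.9
    print(bias_correction(m, 0.9, 1))  # ~1.0, recovering the raw gradient scale
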
diff --git a/reference_algorithms/target_setting_algorithms/jax_nesterov.py b/reference_algorithms/target_setting_algorithms/jax_nesterov.py
index 0c11044fc..714cbb225 100644
--- a/reference_algorithms/target_setting_algorithms/jax_nesterov.py
+++ b/reference_algorithms/target_setting_algorithms/jax_nesterov.py
@@ -8,19 +8,24 @@
import optax
from algoperf import spec
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import \
- get_batch_size # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_submission_base import \
- update_params # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.data_selection import (
+ data_selection,
+) # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.get_batch_size import (
+ get_batch_size,
+) # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.jax_submission_base import (
+ update_params,
+) # pylint: disable=unused-import
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -28,38 +33,44 @@ def init_optimizer_state(workload: spec.Workload,
# Create learning rate schedule.
target_setting_step_hint = int(0.75 * workload.step_hint)
- lr_schedule_fn = create_lr_schedule_fn(target_setting_step_hint,
- hyperparameters)
+ lr_schedule_fn = create_lr_schedule_fn(
+ target_setting_step_hint, hyperparameters
+ )
# Create optimizer.
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
opt_init_fn, opt_update_fn = sgd(
- learning_rate=lr_schedule_fn,
- weight_decay=hyperparameters.weight_decay,
- momentum=hyperparameters.beta1,
- nesterov=True)
+ learning_rate=lr_schedule_fn,
+ weight_decay=hyperparameters.weight_decay,
+ momentum=hyperparameters.beta1,
+ nesterov=True,
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
def create_lr_schedule_fn(
- step_hint: int,
- hyperparameters: spec.Hyperparameters) -> Callable[[int], float]:
+ step_hint: int, hyperparameters: spec.Hyperparameters
+) -> Callable[[int], float]:
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=hyperparameters.warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=hyperparameters.warmup_steps,
+ )
decay_steps = step_hint - hyperparameters.warmup_steps
polynomial_schedule_fn = optax.polynomial_schedule(
- init_value=hyperparameters.learning_rate,
- end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
- power=1,
- transition_steps=int(decay_steps * hyperparameters.decay_steps_factor))
+ init_value=hyperparameters.learning_rate,
+ end_value=hyperparameters.learning_rate * hyperparameters.end_factor,
+ power=1,
+ transition_steps=int(decay_steps * hyperparameters.decay_steps_factor),
+ )
lr_schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, polynomial_schedule_fn],
- boundaries=[hyperparameters.warmup_steps])
+ schedules=[warmup_fn, polynomial_schedule_fn],
+ boundaries=[hyperparameters.warmup_steps],
+ )
return lr_schedule_fn
@@ -85,6 +96,8 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False):
{SGD, Momentum, Nesterov} update.
"""
return optax.chain(
- optax.add_decayed_weights(weight_decay),
- optax.sgd(
- learning_rate=learning_rate, momentum=momentum, nesterov=nesterov))
+ optax.add_decayed_weights(weight_decay),
+ optax.sgd(
+ learning_rate=learning_rate, momentum=momentum, nesterov=nesterov
+ ),
+ )
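
[Editor's note] The sgd() wrapper reformatted here chains optax.add_decayed_weights ahead of optax.sgd, so weight decay is added to the gradient before the (Nesterov) momentum and learning-rate scaling. A minimal sketch with illustrative values:

    import jax.numpy as jnp
    import optax

    tx = optax.chain(
        optax.add_decayed_weights(1e-4),
        optax.sgd(learning_rate=0.1, momentum=0.9, nesterov=True),
    )
    params = {'w': jnp.ones((2,))}
    state = tx.init(params)
    # add_decayed_weights needs the current params, so they are passed to update().
    updates, state = tx.update({'w': jnp.full((2,), 0.5)}, state, params)
    print(optax.apply_updates(params, updates))
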
diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py
index 217228935..557e6957c 100644
--- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py
+++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py
@@ -1,4 +1,5 @@
"""Update submission function in Jax."""
+
import functools
from typing import Any, Dict, List, Optional, Tuple
@@ -13,75 +14,84 @@
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -98,32 +108,46 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- new_optimizer_state, new_params, new_model_state, loss, grad_norm = pmapped_train_step( # pylint: disable=line-too-long
- workload, opt_update_fn, model_state, optimizer_state,
- current_param_container, batch, per_device_rngs, grad_clip,
- label_smoothing)
+ new_optimizer_state, new_params, new_model_state, loss, grad_norm = (
+ pmapped_train_step( # pylint: disable=line-too-long
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
+ )
# Log loss, grad_norm.
- if ((global_step <= 100 or global_step % 500 == 0) and
- workload.metrics_logger is not None):
+ if (
+ global_step <= 100 or global_step % 500 == 0
+ ) and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
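
[Editor's note] The clipping logic in pmapped_train_step scales every gradient leaf by min(1, grad_clip / global_norm). A minimal numerical sketch (eps stands in for _GRAD_CLIP_EPS; values are illustrative):

    import jax
    import jax.numpy as jnp

    grad = {'w': jnp.array([3.0, 4.0])}   # global L2 norm is 5
    grad_clip, eps = 1.0, 1e-6
    grad_norm = jnp.sqrt(
        sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
    )
    scale = jax.lax.clamp(0.0, grad_clip / (grad_norm + eps), 1.0)
    clipped = jax.tree.map(lambda x: x * scale, grad)
    print(grad_norm, clipped)             # 5.0, gradients rescaled to norm ~1
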
diff --git a/reference_algorithms/target_setting_algorithms/pytorch_adamw.py b/reference_algorithms/target_setting_algorithms/pytorch_adamw.py
index c87bdfb7d..f2474a706 100644
--- a/reference_algorithms/target_setting_algorithms/pytorch_adamw.py
+++ b/reference_algorithms/target_setting_algorithms/pytorch_adamw.py
@@ -4,37 +4,44 @@
from algoperf import spec
from reference_algorithms.target_setting_algorithms import cosine_warmup
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import \
- get_batch_size # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \
- update_params # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.data_selection import (
+ data_selection,
+) # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.get_batch_size import (
+ get_batch_size,
+) # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.pytorch_submission_base import (
+ update_params,
+) # pylint: disable=unused-import
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates an AdamW optimizer and a learning rate schedule."""
del model_state
del rng
epsilon = (
- hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8)
+ hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8
+ )
optimizer_state = {
- 'optimizer':
- torch.optim.AdamW(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(hyperparameters.beta1, hyperparameters.beta2),
- eps=epsilon,
- weight_decay=hyperparameters.weight_decay),
+ 'optimizer': torch.optim.AdamW(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(hyperparameters.beta1, hyperparameters.beta2),
+ eps=epsilon,
+ weight_decay=hyperparameters.weight_decay,
+ ),
}
target_setting_step_hint = int(0.75 * workload.step_hint)
optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup(
- target_setting_step_hint, hyperparameters, optimizer_state['optimizer'])
+ target_setting_step_hint, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
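
[Editor's note] The cosine_warmup.pytorch_cosine_warmup helper is not shown in this patch; as an assumed stand-in, the same warmup-plus-cosine pairing can be expressed with PyTorch's built-in schedulers (values below are illustrative, not the submission's):

    import torch

    model = torch.nn.Linear(4, 4)
    opt = torch.optim.AdamW(
        model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2
    )
    warmup = torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.01, total_iters=10)
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=90)
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        opt, schedulers=[warmup, cosine], milestones=[10]
    )
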
diff --git a/reference_algorithms/target_setting_algorithms/pytorch_momentum.py b/reference_algorithms/target_setting_algorithms/pytorch_momentum.py
index 584caff39..030939de5 100644
--- a/reference_algorithms/target_setting_algorithms/pytorch_momentum.py
+++ b/reference_algorithms/target_setting_algorithms/pytorch_momentum.py
@@ -4,40 +4,47 @@
from torch.optim.lr_scheduler import LambdaLR
from algoperf import spec
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import \
- get_batch_size # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_momentum import \
- create_lr_schedule_fn
-from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \
- update_params # pylint: disable=unused-import
-
-
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+from reference_algorithms.target_setting_algorithms.data_selection import (
+ data_selection,
+) # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.get_batch_size import (
+ get_batch_size,
+) # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.jax_momentum import (
+ create_lr_schedule_fn,
+)
+from reference_algorithms.target_setting_algorithms.pytorch_submission_base import (
+ update_params,
+) # pylint: disable=unused-import
+
+
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule."""
del model_state
del rng
# Create optimizer.
optimizer_state = {
- 'optimizer':
- torch.optim.SGD(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- momentum=hyperparameters.beta1,
- weight_decay=hyperparameters.weight_decay,
- nesterov=False),
+ 'optimizer': torch.optim.SGD(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ momentum=hyperparameters.beta1,
+ weight_decay=hyperparameters.weight_decay,
+ nesterov=False,
+ ),
}
# Create learning rate schedule.
target_setting_step_hint = int(0.75 * workload.step_hint)
- lr_schedule_fn = create_lr_schedule_fn(target_setting_step_hint,
- hyperparameters)
+ lr_schedule_fn = create_lr_schedule_fn(
+ target_setting_step_hint, hyperparameters
+ )
# PyTorch's LambdaLR expects the lr_lambda fn to return a factor which will
# be multiplied with the base lr, so we have to divide by it here.
@@ -45,6 +52,7 @@ def _lr_lambda(step: int) -> float:
return lr_schedule_fn(step).item() / hyperparameters.learning_rate
optimizer_state['scheduler'] = LambdaLR(
- optimizer_state['optimizer'], lr_lambda=_lr_lambda)
+ optimizer_state['optimizer'], lr_lambda=_lr_lambda
+ )
return optimizer_state
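
[Editor's note] The comment above about LambdaLR is worth unpacking: the lr_lambda returns a multiplicative factor on the optimizer's base lr, so a schedule expressed in absolute learning-rate units has to be divided by that base lr. A minimal sketch with made-up values:

    import torch

    base_lr = 0.1
    model = torch.nn.Linear(2, 2)
    opt = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=0.9)

    def absolute_schedule(step):
      # Pretend schedule that returns an absolute learning rate.
      return base_lr * (1.0 - step / 100)

    scheduler = torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=lambda step: absolute_schedule(step) / base_lr
    )
    opt.step()
    scheduler.step()
    print(scheduler.get_last_lr())  # factor * base_lr, i.e. ~0.099 after one step
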
diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py
index a9dee1d79..ceeebda6d 100644
--- a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py
+++ b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py
@@ -8,43 +8,43 @@
from algoperf import spec
from reference_algorithms.target_setting_algorithms import cosine_warmup
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import \
- get_batch_size # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \
- update_params # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.data_selection import (
+ data_selection,
+) # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.get_batch_size import (
+ get_batch_size,
+) # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.pytorch_submission_base import (
+ update_params,
+) # pylint: disable=unused-import
# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py
class NAdamW(torch.optim.Optimizer):
r"""Implements NAdamW algorithm.
- See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
- the NAdam algorithm (there is also a comment in the code which highlights
- the only difference of NAdamW and AdamW).
- For further details regarding the algorithm we refer to
- `Decoupled Weight Decay Regularization`_.
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its square (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-8)
- weight_decay (float, optional): weight decay coefficient (default: 1e-2)
- .. _Decoupled Weight Decay Regularization:
- https://arxiv.org/abs/1711.05101
- .. _On the Convergence of Adam and Beyond:
- https://openreview.net/forum?id=ryQu7f-RZ
- """
-
- def __init__(self,
- params,
- lr=1e-3,
- betas=(0.9, 0.999),
- eps=1e-8,
- weight_decay=1e-2):
+ See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
+ the NAdam algorithm (there is also a comment in the code which highlights
+ the only difference of NAdamW and AdamW).
+ For further details regarding the algorithm we refer to
+ `Decoupled Weight Decay Regularization`_.
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+ .. _Decoupled Weight Decay Regularization:
+ https://arxiv.org/abs/1711.05101
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
+ """
+
+ def __init__(
+ self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2
+ ):
if not 0.0 <= lr:
raise ValueError(f'Invalid learning rate: {lr}')
if not 0.0 <= eps:
@@ -56,7 +56,10 @@ def __init__(self,
if not 0.0 <= weight_decay:
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
defaults = {
- 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay
+ 'lr': lr,
+ 'betas': betas,
+ 'eps': eps,
+ 'weight_decay': weight_decay,
}
super().__init__(params, defaults)
@@ -64,7 +67,8 @@ def __setstate__(self, state):
super().__setstate__(state)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
- state_values[0]['step'])
+ state_values[0]['step']
+ )
if not step_is_tensor:
for s in state_values:
s['step'] = torch.tensor(float(s['step']))
@@ -72,10 +76,10 @@ def __setstate__(self, state):
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
- Args:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
- """
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
self._cuda_graph_capture_health_check()
loss = None
@@ -103,51 +107,57 @@ def step(self, closure=None):
# State initialization
if len(state) == 0:
- state['step'] = torch.tensor(0.)
+ state['step'] = torch.tensor(0.0)
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
state_steps.append(state['step'])
nadamw(
- params_with_grad,
- grads,
- exp_avgs,
- exp_avg_sqs,
- state_steps,
- beta1=beta1,
- beta2=beta2,
- lr=group['lr'],
- weight_decay=group['weight_decay'],
- eps=group['eps'])
+ params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ state_steps,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ )
return loss
-def nadamw(params: List[Tensor],
- grads: List[Tensor],
- exp_avgs: List[Tensor],
- exp_avg_sqs: List[Tensor],
- state_steps: List[Tensor],
- beta1: float,
- beta2: float,
- lr: float,
- weight_decay: float,
- eps: float):
+def nadamw(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ state_steps: List[Tensor],
+ beta1: float,
+ beta2: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+):
r"""Functional API that performs NAdamW algorithm computation.
- See NAdamW class for details.
- """
+ See NAdamW class for details.
+ """
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
- 'API has changed, `state_steps` argument must contain a list of' +
- ' singleton tensors')
+ 'API has changed, `state_steps` argument must contain a list of'
+ + ' singleton tensors'
+ )
for i, param in enumerate(params):
grad = grads[i]
@@ -185,28 +195,32 @@ def nadamw(params: List[Tensor],
exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_state
del rng
epsilon = (
- hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8)
+ hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8
+ )
optimizer_state = {
- 'optimizer':
- NAdamW(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(hyperparameters.beta1, hyperparameters.beta2),
- eps=epsilon,
- weight_decay=hyperparameters.weight_decay),
+ 'optimizer': NAdamW(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(hyperparameters.beta1, hyperparameters.beta2),
+ eps=epsilon,
+ weight_decay=hyperparameters.weight_decay,
+ ),
}
target_setting_step_hint = int(0.75 * workload.step_hint)
optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup(
- target_setting_step_hint, hyperparameters, optimizer_state['optimizer'])
+ target_setting_step_hint, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
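
[Editor's note] Since NAdamW keeps the standard torch.optim.Optimizer interface (its constructor signature is visible in the hunk above), it is used like any other PyTorch optimizer. A minimal usage sketch with illustrative data and hyperparameters:

    import torch

    from reference_algorithms.target_setting_algorithms.pytorch_nadamw import NAdamW

    model = torch.nn.Linear(4, 1)
    opt = NAdamW(
        model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2
    )
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
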
diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py b/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py
index 8e10db4ef..ddbcaefdb 100644
--- a/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py
+++ b/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py
@@ -4,40 +4,47 @@
from torch.optim.lr_scheduler import LambdaLR
from algoperf import spec
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import \
- get_batch_size # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_nesterov import \
- create_lr_schedule_fn
-from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \
- update_params # pylint: disable=unused-import
-
-
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+from reference_algorithms.target_setting_algorithms.data_selection import (
+ data_selection,
+) # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.get_batch_size import (
+ get_batch_size,
+) # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.jax_nesterov import (
+ create_lr_schedule_fn,
+)
+from reference_algorithms.target_setting_algorithms.pytorch_submission_base import (
+ update_params,
+) # pylint: disable=unused-import
+
+
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Nesterov optimizer and a learning rate schedule."""
del model_state
del rng
# Create optimizer.
optimizer_state = {
- 'optimizer':
- torch.optim.SGD(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- momentum=hyperparameters.beta1,
- weight_decay=hyperparameters.weight_decay,
- nesterov=True),
+ 'optimizer': torch.optim.SGD(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ momentum=hyperparameters.beta1,
+ weight_decay=hyperparameters.weight_decay,
+ nesterov=True,
+ ),
}
# Create learning rate schedule.
target_setting_step_hint = int(0.75 * workload.step_hint)
- lr_schedule_fn = create_lr_schedule_fn(target_setting_step_hint,
- hyperparameters)
+ lr_schedule_fn = create_lr_schedule_fn(
+ target_setting_step_hint, hyperparameters
+ )
# PyTorch's LambdaLR expects the lr_lambda fn to return a factor which will
# be multiplied with the base lr, so we have to divide by it here.
@@ -45,6 +52,7 @@ def _lr_lambda(step: int) -> float:
return lr_schedule_fn(step).item() / hyperparameters.learning_rate
optimizer_state['scheduler'] = LambdaLR(
- optimizer_state['optimizer'], lr_lambda=_lr_lambda)
+ optimizer_state['optimizer'], lr_lambda=_lr_lambda
+ )
return optimizer_state
diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py
index bbfd8b0f2..36f736a6b 100644
--- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py
+++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py
@@ -13,18 +13,19 @@
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -36,26 +37,30 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -68,7 +73,8 @@ def update_params(
if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
if 'scheduler' in optimizer_state:
optimizer_state['scheduler'].step()
@@ -78,32 +84,39 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
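
[Editor's note] The grad_norm logged in update_params above is the 2-norm of the stacked per-parameter 2-norms, which equals the L2 norm of all gradients concatenated. A tiny numerical sketch:

    import torch

    grads = [torch.tensor([3.0, 4.0]), torch.tensor([12.0])]
    global_norm = torch.norm(torch.stack([torch.norm(g, 2) for g in grads]), 2)
    print(global_norm)  # 13.0 == sqrt(3**2 + 4**2 + 12**2)
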
From e5209d1dce6b9d04722cef9ad2873be7425bd90e Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 10:19:49 +0200
Subject: [PATCH 104/123] Format tests/
---
tests/modeldiffs/criteo1tb/compare.py | 51 ++-
.../criteo1tb_embed_init/compare.py | 51 ++-
.../modeldiffs/criteo1tb_layernorm/compare.py | 51 ++-
tests/modeldiffs/criteo1tb_resnet/compare.py | 51 ++-
tests/modeldiffs/diff.py | 95 ++--
tests/modeldiffs/fastmri/compare.py | 46 +-
tests/modeldiffs/fastmri_layernorm/compare.py | 46 +-
.../modeldiffs/fastmri_model_size/compare.py | 46 +-
tests/modeldiffs/fastmri_tanh/compare.py | 46 +-
tests/modeldiffs/imagenet_resnet/compare.py | 49 +-
.../imagenet_resnet/gelu_compare.py | 43 +-
.../imagenet_resnet/silu_compare.py | 43 +-
tests/modeldiffs/imagenet_vit/compare.py | 57 +--
tests/modeldiffs/imagenet_vit_glu/compare.py | 43 +-
tests/modeldiffs/imagenet_vit_map/compare.py | 43 +-
.../modeldiffs/imagenet_vit_postln/compare.py | 43 +-
.../librispeech_conformer/compare.py | 47 +-
.../compare.py | 47 +-
.../librispeech_conformer_gelu/compare.py | 47 +-
.../compare.py | 47 +-
.../librispeech_deepspeech/compare.py | 57 ++-
.../compare.py | 47 +-
.../librispeech_deepspeech_normaug/compare.py | 47 +-
.../librispeech_deepspeech_tanh/compare.py | 47 +-
tests/modeldiffs/ogbg/compare.py | 67 +--
tests/modeldiffs/ogbg_gelu/compare.py | 67 +--
tests/modeldiffs/ogbg_model_size/compare.py | 67 +--
tests/modeldiffs/ogbg_silu/compare.py | 67 +--
tests/modeldiffs/torch2jax_utils.py | 31 +-
tests/modeldiffs/vanilla_sgd_jax.py | 27 +-
tests/modeldiffs/vanilla_sgd_pytorch.py | 27 +-
tests/modeldiffs/wmt/compare.py | 64 +--
.../modeldiffs/wmt_attention_temp/compare.py | 69 +--
tests/modeldiffs/wmt_glu_tanh/compare.py | 69 +--
tests/modeldiffs/wmt_post_ln/compare.py | 69 +--
tests/reference_algorithm_tests.py | 428 ++++++++++--------
tests/submission_runner_test.py | 70 +--
tests/test_baselines.py | 94 ++--
tests/test_num_params.py | 154 ++++---
tests/test_param_shapes.py | 113 +++--
tests/test_param_types.py | 213 +++++----
tests/test_ssim.py | 35 +-
tests/test_traindiffs.py | 121 ++---
tests/test_version.py | 4 +-
.../imagenet_jax/workload_test.py | 59 +--
45 files changed, 1726 insertions(+), 1379 deletions(-)
diff --git a/tests/modeldiffs/criteo1tb/compare.py b/tests/modeldiffs/criteo1tb/compare.py
index 9de61a2a5..48f658d06 100644
--- a/tests/modeldiffs/criteo1tb/compare.py
+++ b/tests/modeldiffs/criteo1tb/compare.py
@@ -8,10 +8,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \
- Criteo1TbDlrmSmallWorkload as JaxWorkload
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
- Criteo1TbDlrmSmallWorkload as PyTorchWorkload
+from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import (
+ Criteo1TbDlrmSmallWorkload as JaxWorkload,
+)
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import (
+ Criteo1TbDlrmSmallWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -54,31 +56,34 @@ def sd_transform(sd):
pytorch_workload = PyTorchWorkload()
pytorch_batch = {
- 'inputs': torch.ones((2, 13 + 26)),
- 'targets': torch.randint(low=0, high=1, size=(2,)),
- 'weights': torch.ones(2),
+ 'inputs': torch.ones((2, 13 + 26)),
+ 'targets': torch.randint(low=0, high=1, size=(2,)),
+ 'weights': torch.ones(2),
}
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
# Test outputs for identical weights and inputs.
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/criteo1tb_embed_init/compare.py b/tests/modeldiffs/criteo1tb_embed_init/compare.py
index f1897d16f..897920bac 100644
--- a/tests/modeldiffs/criteo1tb_embed_init/compare.py
+++ b/tests/modeldiffs/criteo1tb_embed_init/compare.py
@@ -8,10 +8,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \
- Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
- Criteo1TbDlrmSmallEmbedInitWorkload as PyTorchWorkload
+from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import (
+ Criteo1TbDlrmSmallEmbedInitWorkload as JaxWorkload,
+)
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import (
+ Criteo1TbDlrmSmallEmbedInitWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -53,31 +55,34 @@ def sd_transform(sd):
pytorch_workload = PyTorchWorkload()
pytorch_batch = {
- 'inputs': torch.ones((2, 13 + 26)),
- 'targets': torch.randint(low=0, high=1, size=(2,)),
- 'weights': torch.ones(2),
+ 'inputs': torch.ones((2, 13 + 26)),
+ 'targets': torch.randint(low=0, high=1, size=(2,)),
+ 'weights': torch.ones(2),
}
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
# Test outputs for identical weights and inputs.
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/criteo1tb_layernorm/compare.py b/tests/modeldiffs/criteo1tb_layernorm/compare.py
index 5aad3cc67..db1dec601 100644
--- a/tests/modeldiffs/criteo1tb_layernorm/compare.py
+++ b/tests/modeldiffs/criteo1tb_layernorm/compare.py
@@ -8,10 +8,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \
- Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
- Criteo1TbDlrmSmallLayerNormWorkload as PyTorchWorkload
+from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import (
+ Criteo1TbDlrmSmallLayerNormWorkload as JaxWorkload,
+)
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import (
+ Criteo1TbDlrmSmallLayerNormWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -65,31 +67,34 @@ def sd_transform(sd):
pytorch_workload = PyTorchWorkload()
pytorch_batch = {
- 'inputs': torch.ones((2, 13 + 26)),
- 'targets': torch.randint(low=0, high=1, size=(2,)),
- 'weights': torch.ones(2),
+ 'inputs': torch.ones((2, 13 + 26)),
+ 'targets': torch.randint(low=0, high=1, size=(2,)),
+ 'weights': torch.ones(2),
}
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
# Test outputs for identical weights and inputs.
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- # mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ # mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/criteo1tb_resnet/compare.py b/tests/modeldiffs/criteo1tb_resnet/compare.py
index 169b1cdf4..4851a8ad4 100644
--- a/tests/modeldiffs/criteo1tb_resnet/compare.py
+++ b/tests/modeldiffs/criteo1tb_resnet/compare.py
@@ -9,10 +9,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import \
- Criteo1TbDlrmSmallResNetWorkload as JaxWorkload
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import \
- Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload
+from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import (
+ Criteo1TbDlrmSmallResNetWorkload as JaxWorkload,
+)
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import (
+ Criteo1TbDlrmSmallResNetWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -65,9 +67,9 @@ def sd_transform(sd):
pytorch_workload = PyTorchWorkload()
pytorch_batch = {
- 'inputs': torch.ones((2, 13 + 26)),
- 'targets': torch.randint(low=0, high=1, size=(2,)),
- 'weights': torch.ones(2),
+ 'inputs': torch.ones((2, 13 + 26)),
+ 'targets': torch.randint(low=0, high=1, size=(2,)),
+ 'weights': torch.ones(2),
}
init_fake_batch_size = 2
@@ -80,23 +82,26 @@ def sd_transform(sd):
# Test outputs for identical weights and inputs.
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/diff.py b/tests/modeldiffs/diff.py
index 52241fd3a..8449c3241 100644
--- a/tests/modeldiffs/diff.py
+++ b/tests/modeldiffs/diff.py
@@ -8,14 +8,17 @@
from tests.modeldiffs.torch2jax_utils import value_transform
-#pylint: disable=dangerous-default-value
-def torch2jax(jax_workload,
- pytorch_workload,
- key_transform=None,
- sd_transform=None,
- init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0)):
- jax_params, model_state = jax_workload.init_model_fn(jax.random.PRNGKey(0),
- **init_kwargs)
+# pylint: disable=dangerous-default-value
+def torch2jax(
+ jax_workload,
+ pytorch_workload,
+ key_transform=None,
+ sd_transform=None,
+ init_kwargs=dict(dropout_rate=0.0, aux_dropout_rate=0.0),
+):
+ jax_params, model_state = jax_workload.init_model_fn(
+ jax.random.PRNGKey(0), **init_kwargs
+ )
pytorch_model, _ = pytorch_workload.init_model_fn([0], **init_kwargs)
if isinstance(jax_params, dict):
jax_params = FrozenDict(jax_params)
@@ -24,8 +27,9 @@ def torch2jax(jax_workload,
model_state = jax_utils.unreplicate(model_state)
if isinstance(
- pytorch_model,
- (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
+ pytorch_model,
+ (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel),
+ ):
pytorch_model = pytorch_model.module
# Map and copy params of pytorch_model to jax_model.
t2j = Torch2Jax(torch_model=pytorch_model, jax_model=jax_params)
@@ -39,22 +43,24 @@ def torch2jax(jax_workload,
return jax_params, model_state, pytorch_model
-def out_diff(jax_workload,
- pytorch_workload,
- jax_model_kwargs,
- pytorch_model_kwargs,
- key_transform=None,
- sd_transform=None,
- out_transform=None):
- jax_params, model_state, pytorch_model = torch2jax(jax_workload,
- pytorch_workload,
- key_transform,
- sd_transform)
- out_p, _ = pytorch_workload.model_fn(params=pytorch_model,
- **pytorch_model_kwargs)
- out_j, _ = jax_workload.model_fn(params=jax_params,
- model_state=model_state,
- **jax_model_kwargs)
+def out_diff(
+ jax_workload,
+ pytorch_workload,
+ jax_model_kwargs,
+ pytorch_model_kwargs,
+ key_transform=None,
+ sd_transform=None,
+ out_transform=None,
+):
+ jax_params, model_state, pytorch_model = torch2jax(
+ jax_workload, pytorch_workload, key_transform, sd_transform
+ )
+ out_p, _ = pytorch_workload.model_fn(
+ params=pytorch_model, **pytorch_model_kwargs
+ )
+ out_j, _ = jax_workload.model_fn(
+ params=jax_params, model_state=model_state, **jax_model_kwargs
+ )
if out_transform is not None:
out_p = out_transform(out_p)
out_j = out_transform(out_j)
@@ -67,15 +73,16 @@ def out_diff(jax_workload,
class ModelDiffRunner:
-
- def __init__(self,
- jax_workload,
- pytorch_workload,
- jax_model_kwargs,
- pytorch_model_kwargs,
- key_transform=None,
- sd_transform=None,
- out_transform=None) -> None:
+ def __init__(
+ self,
+ jax_workload,
+ pytorch_workload,
+ jax_model_kwargs,
+ pytorch_model_kwargs,
+ key_transform=None,
+ sd_transform=None,
+ out_transform=None,
+ ) -> None:
"""
Initializes the instance based on diffing logic.
@@ -83,7 +90,7 @@ def __init__(self,
jax_workload: Workload implementation using JAX.
pytorch_workload: Workload implementation using PyTorch.
jax_model_kwargs: Arguments to be used for model_fn in jax workload.
- pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch
+ pytorch_model_kwargs: Arguments to be used for model_fn in PyTorch
workload.
key_transform: Transformation function for keys.
sd_transform: Transformation function for State Dictionary.
@@ -99,10 +106,12 @@ def __init__(self,
self.out_transform = out_transform
def run(self):
- out_diff(self.jax_workload,
- self.pytorch_workload,
- self.jax_model_kwargs,
- self.pytorch_model_kwargs,
- self.key_transform,
- self.sd_transform,
- self.out_transform)
+ out_diff(
+ self.jax_workload,
+ self.pytorch_workload,
+ self.jax_model_kwargs,
+ self.pytorch_model_kwargs,
+ self.key_transform,
+ self.sd_transform,
+ self.out_transform,
+ )
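
[Editor's note] For orientation, out_diff copies the PyTorch weights into the JAX model and then runs both model_fn implementations on the same batch; the final numerical comparison lies outside the hunks shown here, so the check below is an assumed illustration rather than the file's exact assertion.

    import numpy as np

    out_torch = np.array([[0.1234567, -0.8910111]])  # stand-in PyTorch logits
    out_jax = np.array([[0.1234568, -0.8910110]])    # stand-in JAX logits
    abs_diff = np.abs(out_torch - out_jax)
    print(abs_diff.max(), abs_diff.mean())
    assert np.allclose(out_torch, out_jax, atol=1e-5)
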
diff --git a/tests/modeldiffs/fastmri/compare.py b/tests/modeldiffs/fastmri/compare.py
index c1a349cec..6a82bfb58 100644
--- a/tests/modeldiffs/fastmri/compare.py
+++ b/tests/modeldiffs/fastmri/compare.py
@@ -7,15 +7,16 @@
import torch
from algoperf import spec
-from algoperf.workloads.fastmri.fastmri_jax.workload import \
- FastMRIWorkload as JaxWorkload
-from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
- FastMRIWorkload as PyTorchWorkload
+from algoperf.workloads.fastmri.fastmri_jax.workload import (
+ FastMRIWorkload as JaxWorkload,
+)
+from algoperf.workloads.fastmri.fastmri_pytorch.workload import (
+ FastMRIWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
def sd_transform(sd):
-
def sort_key(k):
if k[0] == 'ModuleList_0':
return (0, *k)
@@ -34,7 +35,7 @@ def sort_key(k):
if 'ModuleList' in i or 'Sequential' in i:
continue
if i.startswith('ConvBlock'):
- if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]:
+ if idx != 0 and keys[idx - 1][: idx2 + 1] != k[: idx2 + 1]:
c += 1
i = f'ConvBlock_{c}'
if 'Conv2d' in i:
@@ -64,22 +65,25 @@ def sort_key(k):
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=None,
- sd_transform=sd_transform).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=None,
+ sd_transform=sd_transform,
+ ).run()
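The slice-spacing change above (keys[idx - 1][: idx2 + 1]) is purely cosmetic; the prefix comparison it touches starts a new ConvBlock whenever the module path leading up to it changes. A self-contained sketch with made-up keys, fixing idx2 at the ConvBlock position (in the real script idx2 comes from enumerating the key):

# Renumber ConvBlocks by comparing key prefixes; keys below are invented.
keys = [
  ('ModuleList_0', 'ConvBlock_0', 'Conv2d_0', 'weight'),
  ('ModuleList_0', 'ConvBlock_0', 'Conv2d_1', 'weight'),
  ('ModuleList_1', 'ConvBlock_0', 'Conv2d_0', 'weight'),
]
idx2 = 1  # index of the ConvBlock element within each key
c = 0
for idx, k in enumerate(keys):
  # A changed prefix up to (and including) the ConvBlock element bumps the counter.
  if idx != 0 and keys[idx - 1][: idx2 + 1] != k[: idx2 + 1]:
    c += 1
  print(k, '->', f'ConvBlock_{c}')  # first two map to ConvBlock_0, last to ConvBlock_1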
diff --git a/tests/modeldiffs/fastmri_layernorm/compare.py b/tests/modeldiffs/fastmri_layernorm/compare.py
index f26ad185e..8ad47bcae 100644
--- a/tests/modeldiffs/fastmri_layernorm/compare.py
+++ b/tests/modeldiffs/fastmri_layernorm/compare.py
@@ -7,15 +7,16 @@
import torch
from algoperf import spec
-from algoperf.workloads.fastmri.fastmri_jax.workload import \
- FastMRILayerNormWorkload as JaxWorkload
-from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
- FastMRILayerNormWorkload as PyTorchWorkload
+from algoperf.workloads.fastmri.fastmri_jax.workload import (
+ FastMRILayerNormWorkload as JaxWorkload,
+)
+from algoperf.workloads.fastmri.fastmri_pytorch.workload import (
+ FastMRILayerNormWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
def sd_transform(sd):
-
def sort_key(k):
if k[0] == 'ModuleList_0':
return (0, *k)
@@ -35,7 +36,7 @@ def sort_key(k):
if 'ModuleList' in i or 'Sequential' in i:
continue
if i.startswith('ConvBlock'):
- if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]:
+ if idx != 0 and keys[idx - 1][: idx2 + 1] != k[: idx2 + 1]:
c += 1
i = f'ConvBlock_{c}'
if 'Conv2d' in i:
@@ -71,22 +72,25 @@ def sort_key(k):
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=None,
- sd_transform=sd_transform).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=None,
+ sd_transform=sd_transform,
+ ).run()
diff --git a/tests/modeldiffs/fastmri_model_size/compare.py b/tests/modeldiffs/fastmri_model_size/compare.py
index 42789539b..f6d5c5074 100644
--- a/tests/modeldiffs/fastmri_model_size/compare.py
+++ b/tests/modeldiffs/fastmri_model_size/compare.py
@@ -7,15 +7,16 @@
import torch
from algoperf import spec
-from algoperf.workloads.fastmri.fastmri_jax.workload import \
- FastMRIModelSizeWorkload as JaxWorkload
-from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
- FastMRIModelSizeWorkload as PyTorchWorkload
+from algoperf.workloads.fastmri.fastmri_jax.workload import (
+ FastMRIModelSizeWorkload as JaxWorkload,
+)
+from algoperf.workloads.fastmri.fastmri_pytorch.workload import (
+ FastMRIModelSizeWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
def sd_transform(sd):
-
def sort_key(k):
if k[0] == 'ModuleList_0':
return (0, *k)
@@ -34,7 +35,7 @@ def sort_key(k):
if 'ModuleList' in i or 'Sequential' in i:
continue
if i.startswith('ConvBlock'):
- if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]:
+ if idx != 0 and keys[idx - 1][: idx2 + 1] != k[: idx2 + 1]:
c += 1
i = f'ConvBlock_{c}'
if 'Conv2d' in i:
@@ -64,22 +65,25 @@ def sort_key(k):
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=None,
- sd_transform=sd_transform).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=None,
+ sd_transform=sd_transform,
+ ).run()
diff --git a/tests/modeldiffs/fastmri_tanh/compare.py b/tests/modeldiffs/fastmri_tanh/compare.py
index 13ecb890c..714a025b3 100644
--- a/tests/modeldiffs/fastmri_tanh/compare.py
+++ b/tests/modeldiffs/fastmri_tanh/compare.py
@@ -7,15 +7,16 @@
import torch
from algoperf import spec
-from algoperf.workloads.fastmri.fastmri_jax.workload import \
- FastMRITanhWorkload as JaxWorkload
-from algoperf.workloads.fastmri.fastmri_pytorch.workload import \
- FastMRITanhWorkload as PyTorchWorkload
+from algoperf.workloads.fastmri.fastmri_jax.workload import (
+ FastMRITanhWorkload as JaxWorkload,
+)
+from algoperf.workloads.fastmri.fastmri_pytorch.workload import (
+ FastMRITanhWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
def sd_transform(sd):
-
def sort_key(k):
if k[0] == 'ModuleList_0':
return (0, *k)
@@ -34,7 +35,7 @@ def sort_key(k):
if 'ModuleList' in i or 'Sequential' in i:
continue
if i.startswith('ConvBlock'):
- if idx != 0 and keys[idx - 1][:idx2 + 1] != k[:idx2 + 1]:
+ if idx != 0 and keys[idx - 1][: idx2 + 1] != k[: idx2 + 1]:
c += 1
i = f'ConvBlock_{c}'
if 'Conv2d' in i:
@@ -64,22 +65,25 @@ def sort_key(k):
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=None,
- sd_transform=sd_transform).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=None,
+ sd_transform=sd_transform,
+ ).run()
diff --git a/tests/modeldiffs/imagenet_resnet/compare.py b/tests/modeldiffs/imagenet_resnet/compare.py
index 59ab45555..e43cd069e 100644
--- a/tests/modeldiffs/imagenet_resnet/compare.py
+++ b/tests/modeldiffs/imagenet_resnet/compare.py
@@ -7,10 +7,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \
- ImagenetResNetWorkload as JaxWorkload
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \
- ImagenetResNetWorkload as PyTorchWorkload
+from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import (
+ ImagenetResNetWorkload as JaxWorkload,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import (
+ ImagenetResNetWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -53,11 +55,13 @@ def sd_transform(sd):
c += 1
new_key = (f'BottleneckResNetBlock_{c}',) + k[2:]
if 'Sequential' in ''.join(new_key):
- new_key = tuple([
+ new_key = tuple(
+ [
(i.replace('_0', '_proj') if 'BatchNorm' in i or 'Conv' in i else i)
for i in new_key
if 'Sequential' not in i
- ])
+ ]
+ )
sd[new_key] = sd[k]
del sd[k]
elif 'BatchNorm' in k[0] or 'Conv' in k[0]:
@@ -81,22 +85,25 @@ def sd_transform(sd):
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ ).run()
diff --git a/tests/modeldiffs/imagenet_resnet/gelu_compare.py b/tests/modeldiffs/imagenet_resnet/gelu_compare.py
index 07510ad70..8aa48382d 100644
--- a/tests/modeldiffs/imagenet_resnet/gelu_compare.py
+++ b/tests/modeldiffs/imagenet_resnet/gelu_compare.py
@@ -7,10 +7,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \
- ImagenetResNetGELUWorkload as JaxWorkload
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \
- ImagenetResNetGELUWorkload as PyTorchWorkload
+from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import (
+ ImagenetResNetGELUWorkload as JaxWorkload,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import (
+ ImagenetResNetGELUWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
from tests.modeldiffs.imagenet_resnet.compare import key_transform
from tests.modeldiffs.imagenet_resnet.compare import sd_transform
@@ -28,22 +30,25 @@
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ ).run()
diff --git a/tests/modeldiffs/imagenet_resnet/silu_compare.py b/tests/modeldiffs/imagenet_resnet/silu_compare.py
index 8246d17a2..393badd18 100644
--- a/tests/modeldiffs/imagenet_resnet/silu_compare.py
+++ b/tests/modeldiffs/imagenet_resnet/silu_compare.py
@@ -7,10 +7,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \
- ImagenetResNetSiLUWorkload as JaxWorkload
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \
- ImagenetResNetSiLUWorkload as PyTorchWorkload
+from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import (
+ ImagenetResNetSiLUWorkload as JaxWorkload,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import (
+ ImagenetResNetSiLUWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
from tests.modeldiffs.imagenet_resnet.compare import key_transform
from tests.modeldiffs.imagenet_resnet.compare import sd_transform
@@ -28,22 +30,25 @@
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ ).run()
diff --git a/tests/modeldiffs/imagenet_vit/compare.py b/tests/modeldiffs/imagenet_vit/compare.py
index b4ca7d8ec..84282d4be 100644
--- a/tests/modeldiffs/imagenet_vit/compare.py
+++ b/tests/modeldiffs/imagenet_vit/compare.py
@@ -7,10 +7,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \
- ImagenetVitWorkload as JaxWorkload
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \
- ImagenetVitWorkload as PyTorchWorkload
+from algoperf.workloads.imagenet_vit.imagenet_jax.workload import (
+ ImagenetVitWorkload as JaxWorkload,
+)
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import (
+ ImagenetVitWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -40,16 +42,16 @@ def key_transform(k):
if attention:
if pool_head:
i = {
- 'Linear_0': 'query',
- 'Linear_1': 'key_value',
- 'Linear_2': 'out',
+ 'Linear_0': 'query',
+ 'Linear_1': 'key_value',
+ 'Linear_2': 'out',
}[i]
else:
i = {
- 'Linear_0': 'query',
- 'Linear_1': 'key',
- 'Linear_2': 'value',
- 'Linear_3': 'out',
+ 'Linear_0': 'query',
+ 'Linear_1': 'key',
+ 'Linear_2': 'value',
+ 'Linear_3': 'out',
}[i]
else:
i = i.replace('Linear', 'Dense')
@@ -94,22 +96,25 @@ def key_transform(k):
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/imagenet_vit_glu/compare.py b/tests/modeldiffs/imagenet_vit_glu/compare.py
index c152410b5..55c010b97 100644
--- a/tests/modeldiffs/imagenet_vit_glu/compare.py
+++ b/tests/modeldiffs/imagenet_vit_glu/compare.py
@@ -10,10 +10,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \
- ImagenetVitGluWorkload as JaxWorkload
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \
- ImagenetVitGluWorkload as PyTorchWorkload
+from algoperf.workloads.imagenet_vit.imagenet_jax.workload import (
+ ImagenetVitGluWorkload as JaxWorkload,
+)
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import (
+ ImagenetVitGluWorkload as PyTorchWorkload,
+)
sd_transform = None
@@ -30,22 +32,25 @@
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/imagenet_vit_map/compare.py b/tests/modeldiffs/imagenet_vit_map/compare.py
index 7f1af41ab..17a7483c2 100644
--- a/tests/modeldiffs/imagenet_vit_map/compare.py
+++ b/tests/modeldiffs/imagenet_vit_map/compare.py
@@ -10,10 +10,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \
- ImagenetVitMapWorkload as JaxWorkload
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \
- ImagenetVitMapWorkload as PytWorkload
+from algoperf.workloads.imagenet_vit.imagenet_jax.workload import (
+ ImagenetVitMapWorkload as JaxWorkload,
+)
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import (
+ ImagenetVitMapWorkload as PytWorkload,
+)
def sd_transform(sd):
@@ -41,22 +43,25 @@ def sd_transform(sd):
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ ).run()
diff --git a/tests/modeldiffs/imagenet_vit_postln/compare.py b/tests/modeldiffs/imagenet_vit_postln/compare.py
index a3a639101..72d407a4b 100644
--- a/tests/modeldiffs/imagenet_vit_postln/compare.py
+++ b/tests/modeldiffs/imagenet_vit_postln/compare.py
@@ -10,10 +10,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.imagenet_vit.imagenet_jax.workload import \
- ImagenetVitPostLNWorkload as JaxWorkload
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import \
- ImagenetViTPostLNWorkload as PyTorchWorkload
+from algoperf.workloads.imagenet_vit.imagenet_jax.workload import (
+ ImagenetVitPostLNWorkload as JaxWorkload,
+)
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import (
+ ImagenetViTPostLNWorkload as PyTorchWorkload,
+)
sd_transform = None
@@ -30,22 +32,25 @@
pytorch_batch = {'inputs': image}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/librispeech_conformer/compare.py b/tests/modeldiffs/librispeech_conformer/compare.py
index 664b1242d..80aba62f2 100644
--- a/tests/modeldiffs/librispeech_conformer/compare.py
+++ b/tests/modeldiffs/librispeech_conformer/compare.py
@@ -7,10 +7,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \
- LibriSpeechConformerWorkload as JaxWorkload
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \
- LibriSpeechConformerWorkload as PyTorchWorkload
+from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import (
+ LibriSpeechConformerWorkload as JaxWorkload,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import (
+ LibriSpeechConformerWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -69,24 +71,27 @@ def sd_transform(sd):
pytorch_batch = {'inputs': (wave, pad)}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=lambda out_outpad: out_outpad[0] *
- (1 - out_outpad[1][:, :, None])).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=lambda out_outpad: out_outpad[0]
+ * (1 - out_outpad[1][:, :, None]),
+ ).run()
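The reflowed out_transform above masks padded frames before the Jax and PyTorch outputs are compared. A standalone sketch of that masking with invented shapes (batch of 2, 4 frames, 3 output classes); only the broadcasting pattern is taken from the patch:

# out_outpad[0] is the logits tensor, out_outpad[1] the per-frame padding
# indicator (1 == padded); multiplying by (1 - padding) zeroes padded frames.
import numpy as np

logits = np.random.randn(2, 4, 3)
paddings = np.array([[0., 0., 1., 1.],
                     [0., 0., 0., 1.]])

masked = logits * (1 - paddings[:, :, None])
assert np.allclose(masked[0, 2:], 0.0)  # padded frames contribute nothing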
diff --git a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py
index b0812e77d..a7bebf6bf 100644
--- a/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py
+++ b/tests/modeldiffs/librispeech_conformer_attention_temperature/compare.py
@@ -7,10 +7,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \
- LibriSpeechConformerAttentionTemperatureWorkload as JaxWorkload
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \
- LibriSpeechConformerAttentionTemperatureWorkload as PyTorchWorkload
+from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import (
+ LibriSpeechConformerAttentionTemperatureWorkload as JaxWorkload,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import (
+ LibriSpeechConformerAttentionTemperatureWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -69,24 +71,27 @@ def sd_transform(sd):
pytorch_batch = {'inputs': (wave, pad)}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=lambda out_outpad: out_outpad[0] *
- (1 - out_outpad[1][:, :, None])).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=lambda out_outpad: out_outpad[0]
+ * (1 - out_outpad[1][:, :, None]),
+ ).run()
diff --git a/tests/modeldiffs/librispeech_conformer_gelu/compare.py b/tests/modeldiffs/librispeech_conformer_gelu/compare.py
index 3032a0005..d8f7980a2 100644
--- a/tests/modeldiffs/librispeech_conformer_gelu/compare.py
+++ b/tests/modeldiffs/librispeech_conformer_gelu/compare.py
@@ -7,10 +7,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \
- LibriSpeechConformerGeluWorkload as JaxWorkload
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \
- LibriSpeechConformerGeluWorkload as PyTorchWorkload
+from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import (
+ LibriSpeechConformerGeluWorkload as JaxWorkload,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import (
+ LibriSpeechConformerGeluWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -69,24 +71,27 @@ def sd_transform(sd):
pytorch_batch = {'inputs': (wave, pad)}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=lambda out_outpad: out_outpad[0] *
- (1 - out_outpad[1][:, :, None])).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=lambda out_outpad: out_outpad[0]
+ * (1 - out_outpad[1][:, :, None]),
+ ).run()
diff --git a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py
index d623ef352..7f4768c11 100644
--- a/tests/modeldiffs/librispeech_conformer_layernorm/compare.py
+++ b/tests/modeldiffs/librispeech_conformer_layernorm/compare.py
@@ -7,10 +7,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \
- LibriSpeechConformerLayerNormWorkload as JaxWorkload
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \
- LibriSpeechConformerLayerNormWorkload as PyTorchWorkload
+from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import (
+ LibriSpeechConformerLayerNormWorkload as JaxWorkload,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import (
+ LibriSpeechConformerLayerNormWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -69,24 +71,27 @@ def sd_transform(sd):
pytorch_batch = {'inputs': (wave, pad)}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=lambda out_outpad: out_outpad[0] *
- (1 - out_outpad[1][:, :, None])).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=lambda out_outpad: out_outpad[0]
+ * (1 - out_outpad[1][:, :, None]),
+ ).run()
diff --git a/tests/modeldiffs/librispeech_deepspeech/compare.py b/tests/modeldiffs/librispeech_deepspeech/compare.py
index 84b0a6c86..12cd11513 100644
--- a/tests/modeldiffs/librispeech_deepspeech/compare.py
+++ b/tests/modeldiffs/librispeech_deepspeech/compare.py
@@ -7,10 +7,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \
- LibriSpeechDeepSpeechWorkload as JaxWorkload
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \
- LibriSpeechDeepSpeechWorkload as PyTorchWorkload
+from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import (
+ LibriSpeechDeepSpeechWorkload as JaxWorkload,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import (
+ LibriSpeechDeepSpeechWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
@@ -67,10 +69,12 @@ def sd_transform(sd):
if isinstance(out[k], dict):
kernels = ['kernel_ih_l0', 'kernel_hh_l0']
biases = ['bias_ih_l0', 'bias_hh_l0']
- weights = torch.cat([out[k][i].view(-1) for i in kernels] +
- [out[k][i + '_reverse'].view(-1) for i in kernels] +
- [out[k][i].view(-1) for i in biases] +
- [out[k][i + '_reverse'].view(-1) for i in biases])
+ weights = torch.cat(
+ [out[k][i].view(-1) for i in kernels]
+ + [out[k][i + '_reverse'].view(-1) for i in kernels]
+ + [out[k][i].view(-1) for i in biases]
+ + [out[k][i + '_reverse'].view(-1) for i in biases]
+ )
updates[k + ('weights',)] = weights
keys_to_del.append(k)
out.update(updates)
@@ -94,24 +98,27 @@ def sd_transform(sd):
pytorch_batch = {'inputs': (wave, pad)}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=lambda out_outpad: out_outpad[0] *
- (1 - out_outpad[1][:, :, None])).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=lambda out_outpad: out_outpad[0]
+ * (1 - out_outpad[1][:, :, None]),
+ ).run()
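The torch.cat reflow above flattens the bidirectional LSTM parameters into one 'weights' vector. A minimal sketch of the same concatenation, using torch.nn.LSTM's native parameter names (weight_*/bias_*) instead of the renamed kernel_* keys that the state dict holds at this point, with made-up sizes:

import torch

lstm = torch.nn.LSTM(input_size=8, hidden_size=4, bidirectional=True)
sd = dict(lstm.named_parameters())

kernels = ['weight_ih_l0', 'weight_hh_l0']
biases = ['bias_ih_l0', 'bias_hh_l0']
# Forward kernels, reverse kernels, forward biases, reverse biases -- the same
# ordering the compare script uses.
flat = torch.cat(
  [sd[k].view(-1) for k in kernels]
  + [sd[k + '_reverse'].view(-1) for k in kernels]
  + [sd[k].view(-1) for k in biases]
  + [sd[k + '_reverse'].view(-1) for k in biases]
)
print(flat.shape)  # one flat vector holding all eight parameter tensors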
diff --git a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py
index 2540c1b93..6a719a84a 100644
--- a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py
+++ b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py
@@ -7,10 +7,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \
- LibriSpeechDeepSpeechTanhWorkload as JaxWorkload
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \
- LibriSpeechDeepSpeechTanhWorkload as PyTorchWorkload
+from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import (
+ LibriSpeechDeepSpeechTanhWorkload as JaxWorkload,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import (
+ LibriSpeechDeepSpeechTanhWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
from tests.modeldiffs.librispeech_deepspeech.compare import key_transform
from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform
@@ -30,24 +32,27 @@
pytorch_batch = {'inputs': (wave, pad)}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=lambda out_outpad: out_outpad[0] *
- (1 - out_outpad[1][:, :, None])).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=lambda out_outpad: out_outpad[0]
+ * (1 - out_outpad[1][:, :, None]),
+ ).run()
diff --git a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py
index e5972120d..c8820d397 100644
--- a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py
+++ b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py
@@ -7,10 +7,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \
- LibriSpeechDeepSpeechNormAndSpecAugWorkload as JaxWorkload
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \
- LibriSpeechDeepSpeechNormAndSpecAugWorkload as PyTorchWorkload
+from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import (
+ LibriSpeechDeepSpeechNormAndSpecAugWorkload as JaxWorkload,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import (
+ LibriSpeechDeepSpeechNormAndSpecAugWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
from tests.modeldiffs.librispeech_deepspeech.compare import key_transform
from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform
@@ -30,24 +32,27 @@
pytorch_batch = {'inputs': (wave, pad)}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=lambda out_outpad: out_outpad[0] *
- (1 - out_outpad[1][:, :, None])).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=lambda out_outpad: out_outpad[0]
+ * (1 - out_outpad[1][:, :, None]),
+ ).run()
diff --git a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py
index 4d2c4a5d5..0882f3d1e 100644
--- a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py
+++ b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py
@@ -7,10 +7,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import \
- LibriSpeechDeepSpeechNoResNetWorkload as JaxWorkload
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import \
- LibriSpeechDeepSpeechNoResNetWorkload as PyTorchWorkload
+from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import (
+ LibriSpeechDeepSpeechNoResNetWorkload as JaxWorkload,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import (
+ LibriSpeechDeepSpeechNoResNetWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
from tests.modeldiffs.librispeech_deepspeech.compare import key_transform
from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform
@@ -30,24 +32,27 @@
pytorch_batch = {'inputs': (wave, pad)}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=lambda out_outpad: out_outpad[0] *
- (1 - out_outpad[1][:, :, None])).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=lambda out_outpad: out_outpad[0]
+ * (1 - out_outpad[1][:, :, None]),
+ ).run()
diff --git a/tests/modeldiffs/ogbg/compare.py b/tests/modeldiffs/ogbg/compare.py
index 5d5ef50bf..8c40d3c8a 100644
--- a/tests/modeldiffs/ogbg/compare.py
+++ b/tests/modeldiffs/ogbg/compare.py
@@ -9,10 +9,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.ogbg.ogbg_jax.workload import \
- OgbgWorkload as JaxWorkload
-from algoperf.workloads.ogbg.ogbg_pytorch.workload import \
- OgbgWorkload as PyTorchWorkload
+from algoperf.workloads.ogbg.ogbg_jax.workload import (
+ OgbgWorkload as JaxWorkload,
+)
+from algoperf.workloads.ogbg.ogbg_pytorch.workload import (
+ OgbgWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
# Todo: refactor tests to use workload properties in cleaner way
@@ -41,8 +43,8 @@ def key_transform(k):
layer_index = int(i.split('_')[1])
if graph_network:
count = (
- graph_index * 3 * hidden_dims + seq_index * hidden_dims +
- layer_index)
+ graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index
+ )
i = 'Dense_' + str(count)
elif layer_index == 0:
i = 'node_embedding'
@@ -54,7 +56,8 @@ def key_transform(k):
elif 'LayerNorm' in i:
layer_index = int(i.split('_')[1])
count = (
- graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index)
+ graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index
+ )
i = 'LayerNorm_' + str(count)
elif 'weight' in i:
if bn or ln:
@@ -80,13 +83,14 @@ def sd_transform(sd):
pytorch_workload = PyTorchWorkload()
pytorch_batch = dict(
- n_node=torch.LongTensor([5]),
- n_edge=torch.LongTensor([5]),
- nodes=torch.randn(5, 9),
- edges=torch.randn(5, 3),
- globals=torch.randn(1, 128),
- senders=torch.LongTensor(list(range(5))),
- receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]))
+ n_node=torch.LongTensor([5]),
+ n_edge=torch.LongTensor([5]),
+ nodes=torch.randn(5, 9),
+ edges=torch.randn(5, 3),
+ globals=torch.randn(1, 128),
+ senders=torch.LongTensor(list(range(5))),
+ receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]),
+ )
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
@@ -98,23 +102,26 @@ def sd_transform(sd):
pytorch_batch = {'inputs': graph_p}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
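The count expression reformatted above flattens a (graph block, MLP, layer) position into the sequential Dense_/LayerNorm_ index used by the Flax model. A worked example with invented values; treating hidden_dims as the number of Dense layers per MLP and 3 as the number of MLPs per GraphNetwork block is an assumption, not something stated in this patch:

hidden_dims = 2   # assumed layers per MLP
graph_index = 1   # second GraphNetwork block
seq_index = 2     # third MLP inside that block
layer_index = 1   # second layer inside that MLP

count = graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index
print(f'Dense_{count}')  # -> Dense_11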
diff --git a/tests/modeldiffs/ogbg_gelu/compare.py b/tests/modeldiffs/ogbg_gelu/compare.py
index fc3992998..f35ed8b17 100644
--- a/tests/modeldiffs/ogbg_gelu/compare.py
+++ b/tests/modeldiffs/ogbg_gelu/compare.py
@@ -9,10 +9,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.ogbg.ogbg_jax.workload import \
- OgbgGeluWorkload as JaxWorkload
-from algoperf.workloads.ogbg.ogbg_pytorch.workload import \
- OgbgGeluWorkload as PyTorchWorkload
+from algoperf.workloads.ogbg.ogbg_jax.workload import (
+ OgbgGeluWorkload as JaxWorkload,
+)
+from algoperf.workloads.ogbg.ogbg_pytorch.workload import (
+ OgbgGeluWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
# Todo: refactor tests to use workload properties in cleaner way
@@ -41,8 +43,8 @@ def key_transform(k):
layer_index = int(i.split('_')[1])
if graph_network:
count = (
- graph_index * 3 * hidden_dims + seq_index * hidden_dims +
- layer_index)
+ graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index
+ )
i = 'Dense_' + str(count)
elif layer_index == 0:
i = 'node_embedding'
@@ -54,7 +56,8 @@ def key_transform(k):
elif 'LayerNorm' in i:
layer_index = int(i.split('_')[1])
count = (
- graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index)
+ graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index
+ )
i = 'LayerNorm_' + str(count)
elif 'weight' in i:
if bn or ln:
@@ -80,13 +83,14 @@ def sd_transform(sd):
pytorch_workload = PyTorchWorkload()
pytorch_batch = dict(
- n_node=torch.LongTensor([5]),
- n_edge=torch.LongTensor([5]),
- nodes=torch.randn(5, 9),
- edges=torch.randn(5, 3),
- globals=torch.randn(1, 128),
- senders=torch.LongTensor(list(range(5))),
- receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]))
+ n_node=torch.LongTensor([5]),
+ n_edge=torch.LongTensor([5]),
+ nodes=torch.randn(5, 9),
+ edges=torch.randn(5, 3),
+ globals=torch.randn(1, 128),
+ senders=torch.LongTensor(list(range(5))),
+ receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]),
+ )
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
@@ -98,23 +102,26 @@ def sd_transform(sd):
pytorch_batch = {'inputs': graph_p}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/ogbg_model_size/compare.py b/tests/modeldiffs/ogbg_model_size/compare.py
index e7cfa745c..0042c71af 100644
--- a/tests/modeldiffs/ogbg_model_size/compare.py
+++ b/tests/modeldiffs/ogbg_model_size/compare.py
@@ -9,10 +9,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.ogbg.ogbg_jax.workload import \
- OgbgModelSizeWorkload as JaxWorkload
-from algoperf.workloads.ogbg.ogbg_pytorch.workload import \
- OgbgModelSizeWorkload as PyTorchWorkload
+from algoperf.workloads.ogbg.ogbg_jax.workload import (
+ OgbgModelSizeWorkload as JaxWorkload,
+)
+from algoperf.workloads.ogbg.ogbg_pytorch.workload import (
+ OgbgModelSizeWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
# Todo: refactor tests to use workload properties in cleaner way
@@ -41,8 +43,8 @@ def key_transform(k):
layer_index = int(i.split('_')[1])
if graph_network:
count = (
- graph_index * 3 * hidden_dims + seq_index * hidden_dims +
- layer_index)
+ graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index
+ )
i = 'Dense_' + str(count)
elif layer_index == 0:
i = 'node_embedding'
@@ -54,7 +56,8 @@ def key_transform(k):
elif 'LayerNorm' in i:
layer_index = int(i.split('_')[1])
count = (
- graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index)
+ graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index
+ )
i = 'LayerNorm_' + str(count)
elif 'weight' in i:
if bn or ln:
@@ -80,13 +83,14 @@ def sd_transform(sd):
pytorch_workload = PyTorchWorkload()
pytorch_batch = dict(
- n_node=torch.LongTensor([5]),
- n_edge=torch.LongTensor([5]),
- nodes=torch.randn(5, 9),
- edges=torch.randn(5, 3),
- globals=torch.randn(1, 128),
- senders=torch.LongTensor(list(range(5))),
- receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]))
+ n_node=torch.LongTensor([5]),
+ n_edge=torch.LongTensor([5]),
+ nodes=torch.randn(5, 9),
+ edges=torch.randn(5, 3),
+ globals=torch.randn(1, 128),
+ senders=torch.LongTensor(list(range(5))),
+ receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]),
+ )
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
@@ -98,23 +102,26 @@ def sd_transform(sd):
pytorch_batch = {'inputs': graph_p}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/ogbg_silu/compare.py b/tests/modeldiffs/ogbg_silu/compare.py
index 4e3b96cf7..7583282cd 100644
--- a/tests/modeldiffs/ogbg_silu/compare.py
+++ b/tests/modeldiffs/ogbg_silu/compare.py
@@ -9,10 +9,12 @@
import torch
from algoperf import spec
-from algoperf.workloads.ogbg.ogbg_jax.workload import \
- OgbgSiluWorkload as JaxWorkload
-from algoperf.workloads.ogbg.ogbg_pytorch.workload import \
- OgbgSiluWorkload as PyTorchWorkload
+from algoperf.workloads.ogbg.ogbg_jax.workload import (
+ OgbgSiluWorkload as JaxWorkload,
+)
+from algoperf.workloads.ogbg.ogbg_pytorch.workload import (
+ OgbgSiluWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
# Todo: refactor tests to use workload properties in cleaner way
@@ -41,8 +43,8 @@ def key_transform(k):
layer_index = int(i.split('_')[1])
if graph_network:
count = (
- graph_index * 3 * hidden_dims + seq_index * hidden_dims +
- layer_index)
+ graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index
+ )
i = 'Dense_' + str(count)
elif layer_index == 0:
i = 'node_embedding'
@@ -54,7 +56,8 @@ def key_transform(k):
elif 'LayerNorm' in i:
layer_index = int(i.split('_')[1])
count = (
- graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index)
+ graph_index * 3 * hidden_dims + seq_index * hidden_dims + layer_index
+ )
i = 'LayerNorm_' + str(count)
elif 'weight' in i:
if bn or ln:
@@ -80,13 +83,14 @@ def sd_transform(sd):
pytorch_workload = PyTorchWorkload()
pytorch_batch = dict(
- n_node=torch.LongTensor([5]),
- n_edge=torch.LongTensor([5]),
- nodes=torch.randn(5, 9),
- edges=torch.randn(5, 3),
- globals=torch.randn(1, 128),
- senders=torch.LongTensor(list(range(5))),
- receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]))
+ n_node=torch.LongTensor([5]),
+ n_edge=torch.LongTensor([5]),
+ nodes=torch.randn(5, 9),
+ edges=torch.randn(5, 3),
+ globals=torch.randn(1, 128),
+ senders=torch.LongTensor(list(range(5))),
+ receivers=torch.LongTensor([(i + 1) % 5 for i in range(5)]),
+ )
jax_batch = {k: np.array(v) for k, v in pytorch_batch.items()}
@@ -98,23 +102,26 @@ def sd_transform(sd):
pytorch_batch = {'inputs': graph_p}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py
index d9264b400..7c77a152c 100644
--- a/tests/modeldiffs/torch2jax_utils.py
+++ b/tests/modeldiffs/torch2jax_utils.py
@@ -32,13 +32,17 @@ def flatten(jm, ret, keys=None):
def value_transform(k, value, jax_value):
k_str = ''.join(k).lower()
- if ('conv' in k_str and 'kernel' in k_str) or \
- ('embedding' in k_str and 'kernel' in k_str):
+ if ('conv' in k_str and 'kernel' in k_str) or (
+ 'embedding' in k_str and 'kernel' in k_str
+ ):
if 'transpose' in k_str:
# Assumes 2D ConvTranspose with stride equal to kernel_size.
- return value.reshape(value.shape[0], value.shape[1],
- -1).flip(-1).permute(2, 0,
- 1).reshape(*jax_value.shape)
+ return (
+ value.reshape(value.shape[0], value.shape[1], -1)
+ .flip(-1)
+ .permute(2, 0, 1)
+ .reshape(*jax_value.shape)
+ )
else:
rank = len(value.shape)
if rank == 3:
@@ -51,16 +55,17 @@ def value_transform(k, value, jax_value):
value = value.t().reshape(*list(jax_value.shape))
elif 'attention' in k_str and 'bias' in k_str:
value = value.reshape(*list(jax_value.shape))
- elif ('dense' in k_str and 'kernel' in k_str) or \
- ('lstm' in k_str and 'kernel' in k_str) or \
- ('head' in k_str and 'kernel' in k_str) or \
- ('pre_logits' in k_str and 'kernel' in k_str):
+ elif (
+ ('dense' in k_str and 'kernel' in k_str)
+ or ('lstm' in k_str and 'kernel' in k_str)
+ or ('head' in k_str and 'kernel' in k_str)
+ or ('pre_logits' in k_str and 'kernel' in k_str)
+ ):
value = value.t()
return value
class Torch2Jax:
-
def __init__(self, torch_model, jax_model):
self.torch_model = torch_model
self.jax_model = jax_model
@@ -73,13 +78,13 @@ def __init__(self, torch_model, jax_model):
def key_transform(self, k_transform_fn):
self.pytorch_sd = {
- k_transform_fn(k): self.pytorch_sd[k] for k in self.pytorch_sd
+ k_transform_fn(k): self.pytorch_sd[k] for k in self.pytorch_sd
}
def value_transform(self, v_transform_fn):
self.pytorch_sd = {
- k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k])
- for k in self.pytorch_sd
+ k: v_transform_fn(k, self.pytorch_sd[k], self.flattened_jax_model[k])
+ for k in self.pytorch_sd
}
def sd_transform(self, sd_transform_fn):
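
The conv-transpose branch of value_transform above can be exercised on its own. The sketch below is illustrative only: the concrete shapes and the (in, out, kH, kW) versus (kH, kW, in, out) layout labels are assumptions rather than values pulled from a workload, and it assumes a 2D ConvTranspose with stride equal to kernel size, as noted in the code comment.

import torch

torch_kernel = torch.randn(8, 4, 3, 3)  # assumed PyTorch layout: (in_ch, out_ch, kH, kW)
jax_shape = (3, 3, 8, 4)                # assumed Flax layout: (kH, kW, in_ch, out_ch)

# Mirror the reshape -> flip -> permute -> reshape chain from value_transform.
flat = torch_kernel.reshape(torch_kernel.shape[0], torch_kernel.shape[1], -1)
converted = flat.flip(-1).permute(2, 0, 1).reshape(*jax_shape)
print(converted.shape)                  # torch.Size([3, 3, 8, 4])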
diff --git a/tests/modeldiffs/vanilla_sgd_jax.py b/tests/modeldiffs/vanilla_sgd_jax.py
index 5595894e6..e80e70b8e 100644
--- a/tests/modeldiffs/vanilla_sgd_jax.py
+++ b/tests/modeldiffs/vanilla_sgd_jax.py
@@ -4,25 +4,30 @@
import optax
from algoperf import spec
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_submission_base import \
- update_params # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.data_selection import (
+ data_selection,
+) # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.jax_submission_base import (
+ update_params,
+) # pylint: disable=unused-import
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Vanilla SGD Optimizer."""
del model_params
del model_state
del rng
# Create optimizer.
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
opt_init_fn, opt_update_fn = optax.sgd(learning_rate=0.001)
optimizer_state = opt_init_fn(params_zeros_like)
diff --git a/tests/modeldiffs/vanilla_sgd_pytorch.py b/tests/modeldiffs/vanilla_sgd_pytorch.py
index a6a0c5fa6..d6613479e 100644
--- a/tests/modeldiffs/vanilla_sgd_pytorch.py
+++ b/tests/modeldiffs/vanilla_sgd_pytorch.py
@@ -1,24 +1,29 @@
import torch
from algoperf import spec
-from reference_algorithms.target_setting_algorithms.data_selection import \
- data_selection # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.pytorch_submission_base import \
- update_params # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.data_selection import (
+ data_selection,
+) # pylint: disable=unused-import
+from reference_algorithms.target_setting_algorithms.pytorch_submission_base import (
+ update_params,
+) # pylint: disable=unused-import
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a Vanilla SGD Optimizer."""
del model_state
del rng
optimizer_state = {
- 'optimizer':
- torch.optim.SGD(model_params.parameters(), lr=0.001, weight_decay=0),
+ 'optimizer': torch.optim.SGD(
+ model_params.parameters(), lr=0.001, weight_decay=0
+ ),
}
return optimizer_state
diff --git a/tests/modeldiffs/wmt/compare.py b/tests/modeldiffs/wmt/compare.py
index 109bfa629..02175c8b5 100644
--- a/tests/modeldiffs/wmt/compare.py
+++ b/tests/modeldiffs/wmt/compare.py
@@ -8,17 +8,20 @@
from algoperf import spec
from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWorkload
-from algoperf.workloads.wmt.wmt_pytorch.workload import \
- WmtWorkload as PyTorchWorkload
+from algoperf.workloads.wmt.wmt_pytorch.workload import (
+ WmtWorkload as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
def key_transform(k):
new_key = []
for i in k:
- if 'ModuleList' in i or\
- 'TransformerDecoder_' in i or\
- 'TransformerEncoder_' in i:
+ if (
+ 'ModuleList' in i
+ or 'TransformerDecoder_' in i
+ or 'TransformerEncoder_' in i
+ ):
continue
if 'Linear' in i:
if 'NonDynamicallyQuantizableLinear' in i:
@@ -60,7 +63,7 @@ def sd_transform(sd):
pass
else:
if new_key[-2] == 'Dense_0':
- #q
+ # q
out[(*new_key[:-2], 'query', new_key[-1])] = sd[k]
pass
elif new_key[-2] == 'Dense_1':
@@ -73,11 +76,11 @@ def sd_transform(sd):
out[(*new_key[:-2], 'out', new_key[-1])] = sd[k]
pass
out = {
- tuple(
- k.replace('SelfAttention', 'MultiHeadDotProductAttention')
- for k in key): value
- for key,
- value in out.items()
+ tuple(
+ k.replace('SelfAttention', 'MultiHeadDotProductAttention')
+ for k in key
+ ): value
+ for key, value in out.items()
}
elif 'Dense' in k_str:
new_key = (*k[:2], 'MlpBlock_0', *k[2:])
@@ -112,29 +115,32 @@ def sd_transform(sd):
tgt_tokens = torch.randint(low=0, high=32000, size=(2, 256))
jax_batch = {
- 'inputs': inp_tokens.detach().numpy(),
- 'targets': tgt_tokens.detach().numpy(),
+ 'inputs': inp_tokens.detach().numpy(),
+ 'targets': tgt_tokens.detach().numpy(),
}
pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
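
The key-renaming comprehension in sd_transform above can be tried on a toy state dict; the snippet below is a sketch with made-up key tuples, not keys from the actual WMT models.

# Rename 'SelfAttention' components to Flax's 'MultiHeadDotProductAttention' naming.
state = {
  ('TransformerEncoder', 'SelfAttention_0', 'query', 'kernel'): 1.0,
  ('TransformerEncoder', 'MlpBlock_0', 'Dense_0', 'kernel'): 2.0,
}
state = {
  tuple(k.replace('SelfAttention', 'MultiHeadDotProductAttention') for k in key): value
  for key, value in state.items()
}
print(sorted(state))  # keys now use the MultiHeadDotProductAttention naming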
diff --git a/tests/modeldiffs/wmt_attention_temp/compare.py b/tests/modeldiffs/wmt_attention_temp/compare.py
index 1aa20fe3b..0c834dc86 100644
--- a/tests/modeldiffs/wmt_attention_temp/compare.py
+++ b/tests/modeldiffs/wmt_attention_temp/compare.py
@@ -7,19 +7,23 @@
import torch
from algoperf import spec
-from algoperf.workloads.wmt.wmt_jax.workload import \
- WmtWorkloadAttentionTemp as JaxWorkload
-from algoperf.workloads.wmt.wmt_pytorch.workload import \
- WmtWorkloadAttentionTemp as PyTorchWorkload
+from algoperf.workloads.wmt.wmt_jax.workload import (
+ WmtWorkloadAttentionTemp as JaxWorkload,
+)
+from algoperf.workloads.wmt.wmt_pytorch.workload import (
+ WmtWorkloadAttentionTemp as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
def key_transform(k):
new_key = []
for i in k:
- if 'ModuleList' in i or\
- 'TransformerDecoder_' in i or\
- 'TransformerEncoder_' in i:
+ if (
+ 'ModuleList' in i
+ or 'TransformerDecoder_' in i
+ or 'TransformerEncoder_' in i
+ ):
continue
if 'Linear' in i:
if 'NonDynamicallyQuantizableLinear' in i:
@@ -61,7 +65,7 @@ def sd_transform(sd):
pass
else:
if new_key[-2] == 'Dense_0':
- #q
+ # q
out[(*new_key[:-2], 'query', new_key[-1])] = sd[k]
pass
elif new_key[-2] == 'Dense_1':
@@ -74,11 +78,11 @@ def sd_transform(sd):
out[(*new_key[:-2], 'out', new_key[-1])] = sd[k]
pass
out = {
- tuple(
- k.replace('SelfAttention', 'MultiHeadDotProductAttention')
- for k in key): value
- for key,
- value in out.items()
+ tuple(
+ k.replace('SelfAttention', 'MultiHeadDotProductAttention')
+ for k in key
+ ): value
+ for key, value in out.items()
}
elif 'Dense' in k_str:
new_key = (*k[:2], 'MlpBlock_0', *k[2:])
@@ -113,29 +117,32 @@ def sd_transform(sd):
tgt_tokens = torch.randint(low=0, high=32000, size=(2, 256))
jax_batch = {
- 'inputs': inp_tokens.detach().numpy(),
- 'targets': tgt_tokens.detach().numpy(),
+ 'inputs': inp_tokens.detach().numpy(),
+ 'targets': tgt_tokens.detach().numpy(),
}
pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/wmt_glu_tanh/compare.py b/tests/modeldiffs/wmt_glu_tanh/compare.py
index e98a6945d..f7de12326 100644
--- a/tests/modeldiffs/wmt_glu_tanh/compare.py
+++ b/tests/modeldiffs/wmt_glu_tanh/compare.py
@@ -7,19 +7,23 @@
import torch
from algoperf import spec
-from algoperf.workloads.wmt.wmt_jax.workload import \
- WmtWorkloadGLUTanH as JaxWorkload
-from algoperf.workloads.wmt.wmt_pytorch.workload import \
- WmtWorkloadGLUTanH as PyTorchWorkload
+from algoperf.workloads.wmt.wmt_jax.workload import (
+ WmtWorkloadGLUTanH as JaxWorkload,
+)
+from algoperf.workloads.wmt.wmt_pytorch.workload import (
+ WmtWorkloadGLUTanH as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
def key_transform(k):
new_key = []
for i in k:
- if 'ModuleList' in i or\
- 'TransformerDecoder_' in i or\
- 'TransformerEncoder_' in i:
+ if (
+ 'ModuleList' in i
+ or 'TransformerDecoder_' in i
+ or 'TransformerEncoder_' in i
+ ):
continue
if 'Linear' in i:
if 'NonDynamicallyQuantizableLinear' in i:
@@ -61,7 +65,7 @@ def sd_transform(sd):
pass
else:
if new_key[-2] == 'Dense_0':
- #q
+ # q
out[(*new_key[:-2], 'query', new_key[-1])] = sd[k]
pass
elif new_key[-2] == 'Dense_1':
@@ -74,11 +78,11 @@ def sd_transform(sd):
out[(*new_key[:-2], 'out', new_key[-1])] = sd[k]
pass
out = {
- tuple(
- k.replace('SelfAttention', 'MultiHeadDotProductAttention')
- for k in key): value
- for key,
- value in out.items()
+ tuple(
+ k.replace('SelfAttention', 'MultiHeadDotProductAttention')
+ for k in key
+ ): value
+ for key, value in out.items()
}
elif 'Dense' in k_str:
new_key = (*k[:2], 'MlpBlock_0', *k[2:])
@@ -113,29 +117,32 @@ def sd_transform(sd):
tgt_tokens = torch.randint(low=0, high=32000, size=(2, 256))
jax_batch = {
- 'inputs': inp_tokens.detach().numpy(),
- 'targets': tgt_tokens.detach().numpy(),
+ 'inputs': inp_tokens.detach().numpy(),
+ 'targets': tgt_tokens.detach().numpy(),
}
pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/modeldiffs/wmt_post_ln/compare.py b/tests/modeldiffs/wmt_post_ln/compare.py
index d110715b5..a8681ca8e 100644
--- a/tests/modeldiffs/wmt_post_ln/compare.py
+++ b/tests/modeldiffs/wmt_post_ln/compare.py
@@ -7,19 +7,23 @@
import torch
from algoperf import spec
-from algoperf.workloads.wmt.wmt_jax.workload import \
- WmtWorkloadPostLN as JaxWorkload
-from algoperf.workloads.wmt.wmt_pytorch.workload import \
- WmtWorkloadPostLN as PyTorchWorkload
+from algoperf.workloads.wmt.wmt_jax.workload import (
+ WmtWorkloadPostLN as JaxWorkload,
+)
+from algoperf.workloads.wmt.wmt_pytorch.workload import (
+ WmtWorkloadPostLN as PyTorchWorkload,
+)
from tests.modeldiffs.diff import ModelDiffRunner
def key_transform(k):
new_key = []
for i in k:
- if 'ModuleList' in i or\
- 'TransformerDecoder_' in i or\
- 'TransformerEncoder_' in i:
+ if (
+ 'ModuleList' in i
+ or 'TransformerDecoder_' in i
+ or 'TransformerEncoder_' in i
+ ):
continue
if 'Linear' in i:
if 'NonDynamicallyQuantizableLinear' in i:
@@ -61,7 +65,7 @@ def sd_transform(sd):
pass
else:
if new_key[-2] == 'Dense_0':
- #q
+ # q
out[(*new_key[:-2], 'query', new_key[-1])] = sd[k]
pass
elif new_key[-2] == 'Dense_1':
@@ -74,11 +78,11 @@ def sd_transform(sd):
out[(*new_key[:-2], 'out', new_key[-1])] = sd[k]
pass
out = {
- tuple(
- k.replace('SelfAttention', 'MultiHeadDotProductAttention')
- for k in key): value
- for key,
- value in out.items()
+ tuple(
+ k.replace('SelfAttention', 'MultiHeadDotProductAttention')
+ for k in key
+ ): value
+ for key, value in out.items()
}
elif 'Dense' in k_str:
new_key = (*k[:2], 'MlpBlock_0', *k[2:])
@@ -113,29 +117,32 @@ def sd_transform(sd):
tgt_tokens = torch.randint(low=0, high=32000, size=(2, 256))
jax_batch = {
- 'inputs': inp_tokens.detach().numpy(),
- 'targets': tgt_tokens.detach().numpy(),
+ 'inputs': inp_tokens.detach().numpy(),
+ 'targets': tgt_tokens.detach().numpy(),
}
pytorch_batch = {'inputs': inp_tokens, 'targets': tgt_tokens}
pytorch_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=pytorch_batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=pytorch_batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
jax_model_kwargs = dict(
- augmented_and_preprocessed_input_batch=jax_batch,
- mode=spec.ForwardPassMode.EVAL,
- rng=jax.random.PRNGKey(0),
- update_batch_norm=False)
+ augmented_and_preprocessed_input_batch=jax_batch,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=jax.random.PRNGKey(0),
+ update_batch_norm=False,
+ )
ModelDiffRunner(
- jax_workload=jax_workload,
- pytorch_workload=pytorch_workload,
- jax_model_kwargs=jax_model_kwargs,
- pytorch_model_kwargs=pytorch_model_kwargs,
- key_transform=key_transform,
- sd_transform=sd_transform,
- out_transform=None).run()
+ jax_workload=jax_workload,
+ pytorch_workload=pytorch_workload,
+ jax_model_kwargs=jax_model_kwargs,
+ pytorch_model_kwargs=pytorch_model_kwargs,
+ key_transform=key_transform,
+ sd_transform=sd_transform,
+ out_transform=None,
+ ).run()
diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py
index 58a4a5ddc..d17848aaf 100644
--- a/tests/reference_algorithm_tests.py
+++ b/tests/reference_algorithm_tests.py
@@ -27,67 +27,66 @@
import os
import pickle
-from absl import flags
-from absl import logging
-from absl.testing import absltest
import flax
-from flax import jax_utils
-from flax.core.frozen_dict import FrozenDict
import jax
-from jraph import GraphsTuple
import numpy as np
import tensorflow as tf
import torch
import torch.distributed as dist
+from absl import flags, logging
+from absl.testing import absltest
+from flax import jax_utils
+from flax.core.frozen_dict import FrozenDict
+from jraph import GraphsTuple
-from algoperf import halton
-from algoperf import pytorch_utils
+import submission_runner
+from algoperf import halton, pytorch_utils
from algoperf import random_utils as prng
from algoperf.profiler import PassThroughProfiler
from algoperf.workloads import workloads
from algoperf.workloads.ogbg import input_pipeline as ogbg_input_pipeline
from algoperf.workloads.ogbg.ogbg_pytorch.workload import _graph_map
-import submission_runner
from tests.modeldiffs import diff as diff_utils
flags.DEFINE_integer(
- 'global_batch_size',
- -1,
- ('Global Batch size to use when running an individual workload. Otherwise '
- 'a per-device batch size of 2 is used.'))
+ 'global_batch_size',
+ -1,
+ (
+ 'Global Batch size to use when running an individual workload. Otherwise '
+ 'a per-device batch size of 2 is used.'
+ ),
+)
flags.DEFINE_integer('num_train_steps', 1, 'Number of steps to train.')
flags.DEFINE_boolean('use_fake_input_queue', True, 'Use fake data examples.')
flags.DEFINE_string('log_file', '/tmp/log.pkl', 'The log file')
flags.DEFINE_boolean(
- 'all',
- False,
- 'Run all workloads instead of using --workload and --framework.')
-flags.DEFINE_boolean('identical',
- False,
- 'Run jax and pytorch with identical weights.')
+ 'all', False, 'Run all workloads instead of using --workload and --framework.'
+)
+flags.DEFINE_boolean(
+ 'identical', False, 'Run jax and pytorch with identical weights.'
+)
FLAGS = flags.FLAGS
USE_PYTORCH_DDP, RANK, PYTORCH_DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
tf.config.set_visible_devices([], 'GPU')
_EXPECTED_METRIC_NAMES = {
- 'cifar': ['train/loss', 'validation/loss', 'test/accuracy'],
- 'criteo1tb': ['train/loss', 'validation/loss'],
- 'criteo1tb_test': ['train/loss', 'validation/loss'],
- 'fastmri': ['train/ssim', 'validation/ssim'],
- 'imagenet_resnet': ['train/accuracy', 'validation/accuracy'],
- 'imagenet_vit': ['train/accuracy', 'validation/accuracy'],
- 'librispeech_conformer': ['train/wer', 'validation/wer', 'train/ctc_loss'],
- 'librispeech_deepspeech': ['train/wer', 'validation/wer', 'train/ctc_loss'],
- 'mnist': ['train/loss', 'validation/accuracy', 'test/accuracy'],
- 'ogbg': [
- 'train/accuracy', 'validation/loss', 'test/mean_average_precision'
- ],
- 'wmt': ['train/bleu', 'validation/loss', 'validation/accuracy'],
+ 'cifar': ['train/loss', 'validation/loss', 'test/accuracy'],
+ 'criteo1tb': ['train/loss', 'validation/loss'],
+ 'criteo1tb_test': ['train/loss', 'validation/loss'],
+ 'fastmri': ['train/ssim', 'validation/ssim'],
+ 'imagenet_resnet': ['train/accuracy', 'validation/accuracy'],
+ 'imagenet_vit': ['train/accuracy', 'validation/accuracy'],
+ 'librispeech_conformer': ['train/wer', 'validation/wer', 'train/ctc_loss'],
+ 'librispeech_deepspeech': ['train/wer', 'validation/wer', 'train/ctc_loss'],
+ 'mnist': ['train/loss', 'validation/accuracy', 'test/accuracy'],
+ 'ogbg': ['train/accuracy', 'validation/loss', 'test/mean_average_precision'],
+ 'wmt': ['train/bleu', 'validation/loss', 'validation/accuracy'],
}
def _make_fake_image_batch(batch_shape, data_shape, num_classes):
- examples = np.random.normal(size=(*batch_shape,
- *data_shape)).astype(np.float32)
+ examples = np.random.normal(size=(*batch_shape, *data_shape)).astype(
+ np.float32
+ )
labels = np.random.randint(0, num_classes, size=batch_shape)
masks = np.ones(batch_shape, dtype=np.float32)
return {'inputs': examples, 'targets': labels, 'weights': masks}
@@ -96,16 +95,17 @@ def _make_fake_image_batch(batch_shape, data_shape, num_classes):
def _pytorch_map(inputs):
if USE_PYTORCH_DDP:
return jax.tree.map(
- lambda a: torch.as_tensor(a[RANK], device=PYTORCH_DEVICE), inputs)
+ lambda a: torch.as_tensor(a[RANK], device=PYTORCH_DEVICE), inputs
+ )
return jax.tree.map(
- lambda a: torch.as_tensor(a, device=PYTORCH_DEVICE).view(-1, a.shape[-1])
- if len(a.shape) == 3 else torch.as_tensor(a, device=PYTORCH_DEVICE).view(
- -1),
- inputs)
+ lambda a: torch.as_tensor(a, device=PYTORCH_DEVICE).view(-1, a.shape[-1])
+ if len(a.shape) == 3
+ else torch.as_tensor(a, device=PYTORCH_DEVICE).view(-1),
+ inputs,
+ )
class _FakeTokenizer:
-
def detokenize(self, *args):
del args
return tf.constant('this is a fake sequence?')
@@ -113,15 +113,14 @@ def detokenize(self, *args):
@flax.struct.dataclass
class _FakeMetricsCollection:
-
def merge(self, *args):
del args
return self
def compute(self):
return {
- 'wer': 0.0,
- 'ctc_loss': 0.0,
+ 'wer': 0.0,
+ 'ctc_loss': 0.0,
}
def unreplicate(self):
@@ -129,7 +128,6 @@ def unreplicate(self):
class _FakeMetricsLogger:
-
def __init__(self):
self.filename = FLAGS.log_file
self.scalars = []
@@ -152,27 +150,27 @@ def append_eval_metrics(self, result):
def save(self):
with open(self.filename, 'wb') as f:
- pickle.dump({'scalars': self.scalars, 'eval_results': self.eval_results},
- f)
+ pickle.dump(
+ {'scalars': self.scalars, 'eval_results': self.eval_results}, f
+ )
class _FakeMetricsBundle:
-
def gather_from_model_output(self, *args, **kwargs):
del args
del kwargs
return _FakeMetricsCollection()
-def _make_one_batch_workload(workload_class,
- workload_name,
- framework,
- global_batch_size,
- use_fake_input_queue,
- n_gpus):
-
+def _make_one_batch_workload(
+ workload_class,
+ workload_name,
+ framework,
+ global_batch_size,
+ use_fake_input_queue,
+ n_gpus,
+):
class _OneEvalBatchWorkload(workload_class):
-
def __init__(self):
kwargs = {}
if 'librispeech' in workload_name:
@@ -186,20 +184,27 @@ def __init__(self):
def init_model_fn(self, rng):
# pylint: disable=line-too-long
- if not (FLAGS.identical and
- os.path.exists(f'tests/modeldiffs/{workload_name}/compare.py')):
+ if not (
+ FLAGS.identical
+ and os.path.exists(f'tests/modeldiffs/{workload_name}/compare.py')
+ ):
return super().init_model_fn(rng)
if framework == 'jax':
compare_module = importlib.import_module(
- f'tests.modeldiffs.{workload_name}.compare')
+ f'tests.modeldiffs.{workload_name}.compare'
+ )
jax_params, model_state, _ = diff_utils.torch2jax(
jax_workload=super(),
pytorch_workload=compare_module.PyTorchWorkload(**self.init_kwargs),
key_transform=compare_module.key_transform,
- sd_transform=compare_module.sd_transform)
- return (FrozenDict(**jax_utils.replicate(jax_params)),
- FrozenDict(**jax_utils.replicate(model_state))
- if model_state is not None else model_state)
+ sd_transform=compare_module.sd_transform,
+ )
+ return (
+ FrozenDict(**jax_utils.replicate(jax_params)),
+ FrozenDict(**jax_utils.replicate(model_state))
+ if model_state is not None
+ else model_state,
+ )
return super().init_model_fn([0])
@property
@@ -235,73 +240,74 @@ def _build_input_queue(self, *args, **kwargs):
else:
data_shape = (3, 32, 32)
fake_batch = _make_fake_image_batch(
- batch_shape, data_shape=data_shape, num_classes=10)
+ batch_shape, data_shape=data_shape, num_classes=10
+ )
elif workload_name == 'criteo1tb' or workload_name == 'criteo1tb_test':
targets = np.ones(batch_shape)
targets[0] = 0
fake_batch = {
- 'inputs': np.ones((*batch_shape, 13 + 26)),
- 'targets': targets,
- 'weights': np.ones(batch_shape),
+ 'inputs': np.ones((*batch_shape, 13 + 26)),
+ 'targets': targets,
+ 'weights': np.ones(batch_shape),
}
elif workload_name in ['imagenet_resnet', 'imagenet_vit']:
data_shape = (224, 224, 3)
fake_batch = _make_fake_image_batch(
- batch_shape, data_shape=data_shape, num_classes=1000)
+ batch_shape, data_shape=data_shape, num_classes=1000
+ )
if framework == 'pytorch':
num_dims = len(fake_batch['inputs'].shape)
fake_batch['inputs'] = fake_batch['inputs'].transpose(
- (*range(num_dims - 3), num_dims - 1, num_dims - 3, num_dims - 2))
+ (*range(num_dims - 3), num_dims - 1, num_dims - 3, num_dims - 2)
+ )
elif 'librispeech' in workload_name:
rate = 16000
- l = None
- while l is None or l.shape[-1] < 320000:
+ audio_signal = None
+ while audio_signal is None or audio_signal.shape[-1] < 320000:
duration = 0.5
- freq = 2**(np.random.rand(*batch_shape, 1) * 13)
+ freq = 2 ** (np.random.rand(*batch_shape, 1) * 13)
wav = np.sin(2 * np.pi * freq * np.arange(rate * duration) / rate)
- if l is None:
- l = wav
+ if audio_signal is None:
+ audio_signal = wav
else:
- l = np.concatenate([l, wav], axis=-1)
- inputs = l
+ audio_signal = np.concatenate([audio_signal, wav], axis=-1)
+ inputs = audio_signal
targets = np.random.randint(low=1, high=1024, size=(*batch_shape, 256))
tgt_pad = np.arange(0, 256)[tuple([None] * len(batch_shape))]
tgt_lengths = np.random.randint(
- low=100, high=256, size=(*batch_shape, 1))
+ low=100, high=256, size=(*batch_shape, 1)
+ )
tgt_pad = 1 * (tgt_pad > tgt_lengths)
fake_batch = {
- 'inputs': (inputs, np.zeros_like(inputs)),
- 'targets': (targets, tgt_pad),
+ 'inputs': (inputs, np.zeros_like(inputs)),
+ 'targets': (targets, tgt_pad),
}
elif workload_name == 'mnist':
fake_batch = _make_fake_image_batch(
- batch_shape, data_shape=(28, 28, 1), num_classes=10)
+ batch_shape, data_shape=(28, 28, 1), num_classes=10
+ )
elif workload_name == 'ogbg':
tf.random.set_seed(5)
def _fake_iter():
while True:
fake_batch = {
- 'num_nodes':
- tf.ones((1,), dtype=tf.int64),
- 'edge_index':
- tf.ones((1, 2), dtype=tf.int64),
- 'node_feat':
- tf.random.normal((1, 9)),
- 'edge_feat':
- tf.random.normal((1, 3)),
- 'labels':
- tf.cast(
- tf.random.uniform((self._num_outputs,),
- minval=0,
- maxval=2,
- dtype=tf.int32),
- tf.float32),
+ 'num_nodes': tf.ones((1,), dtype=tf.int64),
+ 'edge_index': tf.ones((1, 2), dtype=tf.int64),
+ 'node_feat': tf.random.normal((1, 9)),
+ 'edge_feat': tf.random.normal((1, 3)),
+ 'labels': tf.cast(
+ tf.random.uniform(
+ (self._num_outputs,), minval=0, maxval=2, dtype=tf.int32
+ ),
+ tf.float32,
+ ),
}
yield fake_batch
fake_batch_iter = ogbg_input_pipeline._get_batch_iterator(
- _fake_iter(), global_batch_size)
+ _fake_iter(), global_batch_size
+ )
fake_batch = next(fake_batch_iter) # pylint: disable=stop-iteration-return
if framework == 'pytorch':
fake_batch['inputs'] = _graph_map(_pytorch_map, fake_batch['inputs'])
@@ -310,48 +316,49 @@ def _fake_iter():
elif workload_name == 'wmt':
max_len = 256
fake_batch = {
- 'inputs':
- np.random.randint(
- low=0, high=32000, size=(*batch_shape, max_len)),
- 'targets':
- np.random.randint(
- low=0, high=32000, size=(*batch_shape, max_len)),
- 'weights':
- np.random.randint(low=0, high=2, size=(*batch_shape, max_len)),
+ 'inputs': np.random.randint(
+ low=0, high=32000, size=(*batch_shape, max_len)
+ ),
+ 'targets': np.random.randint(
+ low=0, high=32000, size=(*batch_shape, max_len)
+ ),
+ 'weights': np.random.randint(
+ low=0, high=2, size=(*batch_shape, max_len)
+ ),
}
self._tokenizer = _FakeTokenizer()
elif workload_name == 'fastmri':
data_shape = (320, 320)
fake_batch = {
- 'inputs':
- _make_fake_image_batch(
- batch_shape, data_shape=data_shape, num_classes=1000)
- ['inputs'],
- 'targets':
- _make_fake_image_batch(
- batch_shape, data_shape=data_shape, num_classes=1000)
- ['inputs'],
- 'mean':
- np.zeros(batch_shape),
- 'std':
- np.ones(batch_shape),
- 'volume_max':
- np.zeros(batch_shape),
- 'weights':
- np.ones(batch_shape),
+ 'inputs': _make_fake_image_batch(
+ batch_shape, data_shape=data_shape, num_classes=1000
+ )['inputs'],
+ 'targets': _make_fake_image_batch(
+ batch_shape, data_shape=data_shape, num_classes=1000
+ )['inputs'],
+ 'mean': np.zeros(batch_shape),
+ 'std': np.ones(batch_shape),
+ 'volume_max': np.zeros(batch_shape),
+ 'weights': np.ones(batch_shape),
}
else:
raise ValueError(
- 'Workload {} does not have a fake batch defined, you '
- 'can add it or use --use_fake_input_queue=false.'.format(
- workload_name))
+ 'Workload {} does not have a fake batch defined, you '
+ 'can add it or use --use_fake_input_queue=false.'.format(
+ workload_name
+ )
+ )
if framework == 'pytorch':
def to_device(k, v):
dtype = (
- torch.long if (k == 'targets' and workload_name != 'fastmri') else
- torch.bool if k == 'weights' else torch.float)
+ torch.long
+ if (k == 'targets' and workload_name != 'fastmri')
+ else torch.bool
+ if k == 'weights'
+ else torch.float
+ )
if USE_PYTORCH_DDP:
v = v[RANK]
return torch.as_tensor(v, device=PYTORCH_DEVICE, dtype=dtype)
@@ -387,24 +394,28 @@ def eval_model(self, *args, **kwargs):
return _OneEvalBatchWorkload()
-def _test_submission(workload_name,
- framework,
- submission_path,
- search_space_path,
- data_dir,
- use_fake_input_queue,
- n_gpus):
+def _test_submission(
+ workload_name,
+ framework,
+ submission_path,
+ search_space_path,
+ data_dir,
+ use_fake_input_queue,
+ n_gpus,
+):
logging.info(f'========= Testing {workload_name} in {framework}.')
FLAGS.framework = framework
workload_metadata = copy.deepcopy(submission_runner.WORKLOADS[workload_name])
workload_metadata['workload_path'] = os.path.join(
- submission_runner.BASE_WORKLOADS_DIR,
- workload_metadata['workload_path'] + '_' + framework,
- 'workload.py')
+ submission_runner.BASE_WORKLOADS_DIR,
+ workload_metadata['workload_path'] + '_' + framework,
+ 'workload.py',
+ )
workload_class = workloads.import_workload(
- workload_path=workload_metadata['workload_path'],
- workload_class_name=workload_metadata['workload_class_name'],
- return_class=True)
+ workload_path=workload_metadata['workload_path'],
+ workload_class_name=workload_metadata['workload_class_name'],
+ return_class=True,
+ )
print(f'Workload class for {workload_name} is {workload_class}')
submission_module_path = workloads.convert_filepath_to_module(submission_path)
@@ -421,30 +432,32 @@ def _test_submission(workload_name,
global_batch_size = FLAGS.global_batch_size
if FLAGS.global_batch_size < 0:
raise ValueError('Must set --global_batch_size.')
- workload = _make_one_batch_workload(workload_class,
- workload_name,
- framework,
- global_batch_size,
- use_fake_input_queue,
- n_gpus)
+ workload = _make_one_batch_workload(
+ workload_class,
+ workload_name,
+ framework,
+ global_batch_size,
+ use_fake_input_queue,
+ n_gpus,
+ )
# Get a sample hyperparameter setting.
hyperparameters = {}
if search_space_path != 'None':
with open(search_space_path, 'r', encoding='UTF-8') as search_space_file:
hyperparameters = halton.generate_search(
- json.load(search_space_file), num_trials=1)[0]
+ json.load(search_space_file), num_trials=1
+ )[0]
rng = prng.PRNGKey(0)
data_rng, opt_init_rng, model_init_rng, rng = prng.split(rng, 4)
input_queue = workload._build_input_queue(
- data_rng, 'train', data_dir=data_dir, global_batch_size=global_batch_size)
+ data_rng, 'train', data_dir=data_dir, global_batch_size=global_batch_size
+ )
model_params, model_state = workload.init_model_fn(model_init_rng)
- optimizer_state = init_optimizer_state(workload,
- model_params,
- model_state,
- hyperparameters,
- opt_init_rng)
+ optimizer_state = init_optimizer_state(
+ workload, model_params, model_state, hyperparameters, opt_init_rng
+ )
if USE_PYTORCH_DDP:
torch.cuda.empty_cache()
@@ -452,44 +465,49 @@ def _test_submission(workload_name,
for global_step in range(FLAGS.num_train_steps):
step_rng = prng.fold_in(rng, global_step)
data_select_rng, update_rng, eval_rng = prng.split(step_rng, 3)
- batch = data_selection(workload,
- input_queue,
- optimizer_state,
- model_params,
- model_state,
- hyperparameters,
- global_step,
- data_select_rng)
+ batch = data_selection(
+ workload,
+ input_queue,
+ optimizer_state,
+ model_params,
+ model_state,
+ hyperparameters,
+ global_step,
+ data_select_rng,
+ )
optimizer_state, model_params, model_state = update_params(
- workload=workload,
- current_param_container=model_params,
- current_params_types=workload.model_params_types,
- model_state=model_state,
- hyperparameters=hyperparameters,
- batch=batch,
- loss_type=workload.loss_type,
- optimizer_state=optimizer_state,
- train_state={},
- eval_results=[],
- global_step=global_step,
- rng=update_rng)
+ workload=workload,
+ current_param_container=model_params,
+ current_params_types=workload.model_params_types,
+ model_state=model_state,
+ hyperparameters=hyperparameters,
+ batch=batch,
+ loss_type=workload.loss_type,
+ optimizer_state=optimizer_state,
+ train_state={},
+ eval_results=[],
+ global_step=global_step,
+ rng=update_rng,
+ )
eval_result = workload.eval_model(
- global_batch_size,
- model_params,
- model_state,
- eval_rng,
- data_dir,
- imagenet_v2_data_dir=None,
- global_step=global_step)
- _ = workload.eval_model(
global_batch_size,
model_params,
model_state,
eval_rng,
data_dir,
imagenet_v2_data_dir=None,
- global_step=global_step)
+ global_step=global_step,
+ )
+ _ = workload.eval_model(
+ global_batch_size,
+ model_params,
+ model_state,
+ eval_rng,
+ data_dir,
+ imagenet_v2_data_dir=None,
+ global_step=global_step,
+ )
return eval_result
@@ -499,12 +517,15 @@ def _make_paths(repo_location, framework, workload_name):
else:
dataset_name = workload_name
workload_dir = (
- f'{repo_location}/reference_algorithms/target_setting_algorithms/'
- f'{workload_name}')
+ f'{repo_location}/reference_algorithms/target_setting_algorithms/'
+ f'{workload_name}'
+ )
search_space_path = f'{workload_dir}/tuning_search_space.json'
- submission_path = (f'reference_algorithms/target_setting_algorithms/'
- f'{workload_name}/{dataset_name}_{framework}/'
- 'submission.py')
+ submission_path = (
+ f'reference_algorithms/target_setting_algorithms/'
+ f'{workload_name}/{dataset_name}_{framework}/'
+ 'submission.py'
+ )
full_submission_path = f'{repo_location}/{submission_path}'
if not os.path.exists(full_submission_path):
return None, None
@@ -534,7 +555,8 @@ def test_submission(self):
if FLAGS.tuning_search_space:
raise ValueError('Cannot set --tuning_search_space and --all.')
references_dir = (
- f'{repo_location}/reference_algorithms/target_setting_algorithms')
+ f'{repo_location}/reference_algorithms/target_setting_algorithms'
+ )
for workload_name in os.listdir(references_dir):
for framework in ['jax', 'pytorch']:
if framework == 'pytorch':
@@ -542,17 +564,19 @@ def test_submission(self):
# First jax operation has to be called after pytorch_init.
n_gpus = max(N_GPUS, jax.local_device_count())
search_space_path, submission_path = _make_paths(
- repo_location, framework, workload_name)
+ repo_location, framework, workload_name
+ )
if search_space_path is None:
continue
eval_result = _test_submission(
- workload_name,
- framework,
- submission_path,
- search_space_path,
- data_dir=FLAGS.data_dir,
- use_fake_input_queue=FLAGS.use_fake_input_queue,
- n_gpus=n_gpus)
+ workload_name,
+ framework,
+ submission_path,
+ search_space_path,
+ data_dir=FLAGS.data_dir,
+ use_fake_input_queue=FLAGS.use_fake_input_queue,
+ n_gpus=n_gpus,
+ )
self._assert_eval_result(workload_name, eval_result)
else:
framework = FLAGS.framework
@@ -566,15 +590,17 @@ def test_submission(self):
submission_path = FLAGS.submission_path
else:
search_space_path, submission_path = _make_paths(
- repo_location, framework, workload_name)
+ repo_location, framework, workload_name
+ )
eval_result = _test_submission(
- workload_name,
- framework,
- submission_path,
- search_space_path,
- data_dir=FLAGS.data_dir,
- use_fake_input_queue=FLAGS.use_fake_input_queue,
- n_gpus=n_gpus)
+ workload_name,
+ framework,
+ submission_path,
+ search_space_path,
+ data_dir=FLAGS.data_dir,
+ use_fake_input_queue=FLAGS.use_fake_input_queue,
+ n_gpus=n_gpus,
+ )
self._assert_eval_result(workload_name, eval_result)
if USE_PYTORCH_DDP:
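
The fake LibriSpeech audio assembled in _build_input_queue above is just concatenated constant-frequency sine segments; a standalone version, assuming batch_shape is (2,), looks like this.

import numpy as np

rate, batch_shape = 16000, (2,)
audio_signal = None
while audio_signal is None or audio_signal.shape[-1] < 320000:
  duration = 0.5
  freq = 2 ** (np.random.rand(*batch_shape, 1) * 13)  # random frequency per segment
  wav = np.sin(2 * np.pi * freq * np.arange(rate * duration) / rate)
  audio_signal = wav if audio_signal is None else np.concatenate([audio_signal, wav], axis=-1)
print(audio_signal.shape)  # (2, 320000): forty 0.5 s segments of 8000 samples each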
diff --git a/tests/submission_runner_test.py b/tests/submission_runner_test.py
index ff724b201..b9beb9101 100644
--- a/tests/submission_runner_test.py
+++ b/tests/submission_runner_test.py
@@ -4,6 +4,7 @@
dataset to be available. For testing the workload and reference submission code
for all workloads, see reference_algorithm_tests.py.
"""
+
import copy
import os
import sys
@@ -28,47 +29,46 @@ class SubmissionRunnerTest(parameterized.TestCase):
"""Tests for reference submissions."""
@parameterized.named_parameters(
- dict(
- testcase_name='mnist_jax',
- workload='mnist',
- framework='jax',
- submission_path=(f'{_MNIST_DEV_ALGO_DIR}/mnist_jax/submission.py'),
- tuning_search_space=(
- f'{_MNIST_DEV_ALGO_DIR}/tuning_search_space.json')),
- dict(
- testcase_name='mnist_pytorch',
- workload='mnist',
- framework='pytorch',
- submission_path=(
- f'{_MNIST_DEV_ALGO_DIR}/mnist_pytorch/submission.py'),
- tuning_search_space=(
- f'{_MNIST_DEV_ALGO_DIR}/tuning_search_space.json')),
+ dict(
+ testcase_name='mnist_jax',
+ workload='mnist',
+ framework='jax',
+ submission_path=(f'{_MNIST_DEV_ALGO_DIR}/mnist_jax/submission.py'),
+ tuning_search_space=(f'{_MNIST_DEV_ALGO_DIR}/tuning_search_space.json'),
+ ),
+ dict(
+ testcase_name='mnist_pytorch',
+ workload='mnist',
+ framework='pytorch',
+ submission_path=(f'{_MNIST_DEV_ALGO_DIR}/mnist_pytorch/submission.py'),
+ tuning_search_space=(f'{_MNIST_DEV_ALGO_DIR}/tuning_search_space.json'),
+ ),
)
- def test_submission(self,
- workload,
- framework,
- submission_path,
- tuning_search_space):
+ def test_submission(
+ self, workload, framework, submission_path, tuning_search_space
+ ):
FLAGS.framework = framework
workload_metadata = copy.deepcopy(submission_runner.WORKLOADS[workload])
workload_metadata['workload_path'] = os.path.join(
- submission_runner.BASE_WORKLOADS_DIR,
- workload_metadata['workload_path'] + '_' + framework,
- 'workload.py')
+ submission_runner.BASE_WORKLOADS_DIR,
+ workload_metadata['workload_path'] + '_' + framework,
+ 'workload.py',
+ )
workload_obj = submission_runner.import_workload(
- workload_path=workload_metadata['workload_path'],
- workload_class_name=workload_metadata['workload_class_name'],
- workload_init_kwargs={})
+ workload_path=workload_metadata['workload_path'],
+ workload_class_name=workload_metadata['workload_class_name'],
+ workload_init_kwargs={},
+ )
score = submission_runner.score_submission_on_workload(
- workload_obj,
- workload,
- submission_path,
- data_dir='~/tensorflow_datasets', # The default in TFDS.
- tuning_ruleset='external',
- tuning_search_space=tuning_search_space,
- num_tuning_trials=1,
- profiler=PassThroughProfiler(),
- max_global_steps=500,
+ workload_obj,
+ workload,
+ submission_path,
+ data_dir='~/tensorflow_datasets', # The default in TFDS.
+ tuning_ruleset='external',
+ tuning_search_space=tuning_search_space,
+ num_tuning_trials=1,
+ profiler=PassThroughProfiler(),
+ max_global_steps=500,
)
logging.info(score)
diff --git a/tests/test_baselines.py b/tests/test_baselines.py
index b2be8aa11..9ebc50222 100644
--- a/tests/test_baselines.py
+++ b/tests/test_baselines.py
@@ -1,8 +1,9 @@
"""Tests for submission.py for baselines.
-This is an end-to-end test for all baselines on MNIST in PyTorch and Jax that
+This is an end-to-end test for all baselines on MNIST in PyTorch and Jax that
requires the dataset to be available.
"""
+
import copy
import os
import sys
@@ -24,41 +25,42 @@
MAX_GLOBAL_STEPS = 5
baselines = {
- 'jax': [
- 'adafactor',
- 'adamw',
- 'lamb',
- 'momentum',
- 'nadamw',
- 'nesterov',
- 'sam',
- 'shampoo',
- ],
- 'pytorch': [
- 'adamw',
- 'momentum',
- 'nadamw',
- 'nesterov',
- ],
+ 'jax': [
+ 'adafactor',
+ 'adamw',
+ 'lamb',
+ 'momentum',
+ 'nadamw',
+ 'nesterov',
+ 'sam',
+ 'shampoo',
+ ],
+ 'pytorch': [
+ 'adamw',
+ 'momentum',
+ 'nadamw',
+ 'nesterov',
+ ],
}
frameworks = [
- 'pytorch',
- 'jax',
+ 'pytorch',
+ 'jax',
]
-baseline_path = "reference_algorithms/paper_baselines"
+baseline_path = 'reference_algorithms/paper_baselines'
named_parameters = []
for f in frameworks:
for b in baselines[f]:
named_parameters.append(
- dict(
- testcase_name=f'{b}_{f}',
- workload='mnist',
- framework=f'{f}',
- submission_path=f'{baseline_path}/{b}/{f}/submission.py',
- tuning_search_space=f'{baseline_path}/{b}/tuning_search_space.json')
+ dict(
+ testcase_name=f'{b}_{f}',
+ workload='mnist',
+ framework=f'{f}',
+ submission_path=f'{baseline_path}/{b}/{f}/submission.py',
+ tuning_search_space=f'{baseline_path}/{b}/tuning_search_space.json',
+ )
)
@@ -66,31 +68,31 @@ class BaselineTest(parameterized.TestCase):
"""Tests for reference submissions."""
@parameterized.named_parameters(*named_parameters)
- def test_baseline_submission(self,
- workload,
- framework,
- submission_path,
- tuning_search_space):
+ def test_baseline_submission(
+ self, workload, framework, submission_path, tuning_search_space
+ ):
FLAGS.framework = framework
workload_metadata = copy.deepcopy(workloads.WORKLOADS[workload])
workload_metadata['workload_path'] = os.path.join(
- workloads.BASE_WORKLOADS_DIR,
- workload_metadata['workload_path'] + '_' + framework,
- 'workload.py')
+ workloads.BASE_WORKLOADS_DIR,
+ workload_metadata['workload_path'] + '_' + framework,
+ 'workload.py',
+ )
workload_obj = workloads.import_workload(
- workload_path=workload_metadata['workload_path'],
- workload_class_name=workload_metadata['workload_class_name'],
- workload_init_kwargs={})
+ workload_path=workload_metadata['workload_path'],
+ workload_class_name=workload_metadata['workload_class_name'],
+ workload_init_kwargs={},
+ )
score = submission_runner.score_submission_on_workload(
- workload_obj,
- workload,
- submission_path,
- data_dir='~/tensorflow_datasets', # The default in TFDS.
- tuning_ruleset='external',
- tuning_search_space=tuning_search_space,
- num_tuning_trials=1,
- profiler=PassThroughProfiler(),
- max_global_steps=MAX_GLOBAL_STEPS,
+ workload_obj,
+ workload,
+ submission_path,
+ data_dir='~/tensorflow_datasets', # The default in TFDS.
+ tuning_ruleset='external',
+ tuning_search_space=tuning_search_space,
+ num_tuning_trials=1,
+ profiler=PassThroughProfiler(),
+ max_global_steps=MAX_GLOBAL_STEPS,
)
logging.info(score)
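
The named_parameters list consumed by BaselineTest above is a plain cross product of frameworks and baselines; the sketch below reproduces it with a trimmed-down baselines mapping (the real mapping lists all of the JAX and PyTorch optimizers shown in the diff).

baselines = {'jax': ['adamw', 'nadamw'], 'pytorch': ['adamw']}
frameworks = ['pytorch', 'jax']
baseline_path = 'reference_algorithms/paper_baselines'

named_parameters = []
for f in frameworks:
  for b in baselines[f]:
    named_parameters.append(
      dict(
        testcase_name=f'{b}_{f}',
        workload='mnist',
        framework=f,
        submission_path=f'{baseline_path}/{b}/{f}/submission.py',
        tuning_search_space=f'{baseline_path}/{b}/tuning_search_space.json',
      )
    )
print([p['testcase_name'] for p in named_parameters])  # ['adamw_pytorch', 'adamw_jax', 'nadamw_jax']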
diff --git a/tests/test_num_params.py b/tests/test_num_params.py
index b0633025e..9361f4c72 100644
--- a/tests/test_num_params.py
+++ b/tests/test_num_params.py
@@ -5,48 +5,59 @@
import pytest
import torch
-from algoperf.workloads.criteo1tb.criteo1tb_jax.models import \
- DlrmSmall as JaxDlrmSmall
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import \
- DlrmSmall as PyTorchDlrmSmall
-from algoperf.workloads.imagenet_resnet.imagenet_jax.models import \
- ResNet18 as JaxResNet_c10
-from algoperf.workloads.imagenet_resnet.imagenet_jax.models import \
- ResNet50 as JaxResNet
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \
- resnet18 as PyTorchResNet_c10
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \
- resnet50 as PyTorchResNet
+from algoperf.workloads.criteo1tb.criteo1tb_jax.models import (
+ DlrmSmall as JaxDlrmSmall,
+)
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.models import (
+ DlrmSmall as PyTorchDlrmSmall,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_jax.models import (
+ ResNet18 as JaxResNet_c10,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_jax.models import (
+ ResNet50 as JaxResNet,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import (
+ resnet18 as PyTorchResNet_c10,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import (
+ resnet50 as PyTorchResNet,
+)
from algoperf.workloads.imagenet_vit.imagenet_jax.models import ViT as JaxViT
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import \
- ViT as PyTorchViT
-from algoperf.workloads.librispeech_conformer.librispeech_jax.models import \
- Conformer as JaxConformer
-from algoperf.workloads.librispeech_conformer.librispeech_jax.models import \
- ConformerConfig as JaxConformerConfig
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
- ConformerConfig as PytorchConformerConfig
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
- ConformerEncoderDecoder as PytorchConformer
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.models import (
+ ViT as PyTorchViT,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_jax.models import (
+ Conformer as JaxConformer,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_jax.models import (
+ ConformerConfig as JaxConformerConfig,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import (
+ ConformerConfig as PytorchConformerConfig,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import (
+ ConformerEncoderDecoder as PytorchConformer,
+)
from algoperf.workloads.mnist.mnist_jax.workload import _Model as JaxMLP
-from algoperf.workloads.mnist.mnist_pytorch.workload import \
- _Model as PyTorchMLP
+from algoperf.workloads.mnist.mnist_pytorch.workload import _Model as PyTorchMLP
from algoperf.workloads.ogbg.ogbg_jax.models import GNN as JaxGNN
from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN as PyTorchGNN
from algoperf.workloads.wmt.wmt_jax.models import Transformer as JaxTransformer
from algoperf.workloads.wmt.wmt_jax.models import TransformerConfig
-from algoperf.workloads.wmt.wmt_pytorch.models import \
- Transformer as PyTorchTransformer
+from algoperf.workloads.wmt.wmt_pytorch.models import (
+ Transformer as PyTorchTransformer,
+)
WORKLOADS = [
- 'mnist',
- 'cifar',
- 'criteo1tb',
- 'imagenet_resnet',
- 'imagenet_vit',
- 'wmt',
- 'ogbg',
- 'librispeech_conformer',
+ 'mnist',
+ 'cifar',
+ 'criteo1tb',
+ 'imagenet_resnet',
+ 'imagenet_vit',
+ 'wmt',
+ 'ogbg',
+ 'librispeech_conformer',
]
@@ -56,7 +67,8 @@ def test_matching_num_params(workload):
# Count parameters of both models.
num_jax_params = sum(x.size for x in jax.tree_util.tree_leaves(jax_model))
num_pytorch_params = sum(
- p.numel() for p in pytorch_model.parameters() if p.requires_grad)
+ p.numel() for p in pytorch_model.parameters() if p.requires_grad
+ )
assert num_jax_params == num_pytorch_params
@@ -72,8 +84,9 @@ def get_models(workload):
# Init Jax model.
input_shape = (1, 32, 32, 3)
model_init = jax.jit(JaxResNet_c10(num_classes=10, dtype=jnp.float32).init)
- jax_model = model_init(init_rngs, jnp.ones(input_shape,
- jnp.float32))["params"]
+ jax_model = model_init(init_rngs, jnp.ones(input_shape, jnp.float32))[
+ 'params'
+ ]
# Init PyTorch model.
pytorch_model = PyTorchResNet_c10(num_classes=10)
@@ -85,35 +98,38 @@ def get_models(workload):
vocab_size = 32 * 128 * 1024
input_shape = (1, 39)
model_init = JaxDlrmSmall(
- vocab_size=vocab_size,
- num_dense_features=13,
- mlp_bottom_dims=mlp_bottom_dims,
- mlp_top_dims=mlp_top_dims,
- embed_dim=embed_dim).init
- jax_model = model_init(init_rngs, jnp.ones(input_shape, jnp.float32),
- False)['params']
+ vocab_size=vocab_size,
+ num_dense_features=13,
+ mlp_bottom_dims=mlp_bottom_dims,
+ mlp_top_dims=mlp_top_dims,
+ embed_dim=embed_dim,
+ ).init
+ jax_model = model_init(
+ init_rngs, jnp.ones(input_shape, jnp.float32), False
+ )['params']
# Init PyTorch model.
pytorch_model = PyTorchDlrmSmall(
- vocab_size=vocab_size,
- num_dense_features=13,
- mlp_bottom_dims=mlp_bottom_dims,
- mlp_top_dims=mlp_top_dims,
- embed_dim=embed_dim)
+ vocab_size=vocab_size,
+ num_dense_features=13,
+ mlp_bottom_dims=mlp_bottom_dims,
+ mlp_top_dims=mlp_top_dims,
+ embed_dim=embed_dim,
+ )
elif workload == 'imagenet_resnet':
# Init Jax model.
input_shape = (1, 224, 224, 3)
- jax_model = JaxResNet(
- num_classes=1000,
- dtype=jnp.float32).init(init_rngs, jnp.ones(input_shape,
- jnp.float32))['params']
+ jax_model = JaxResNet(num_classes=1000, dtype=jnp.float32).init(
+ init_rngs, jnp.ones(input_shape, jnp.float32)
+ )['params']
# Init PyTorch model.
pytorch_model = PyTorchResNet()
elif workload == 'imagenet_vit':
# Init Jax model.
input_shape = (1, 224, 224, 3)
jax_model = JaxViT(num_classes=1000).init(
- init_rngs, jnp.ones(input_shape, jnp.float32))['params']
+ init_rngs, jnp.ones(input_shape, jnp.float32)
+ )['params']
# Init PyTorch model.
pytorch_model = PyTorchViT()
elif workload == 'librispeech_conformer':
@@ -123,8 +139,9 @@ def get_models(workload):
# Init Jax model
input_shape = [(320000,), (320000,)]
fake_input_batch = [jnp.zeros((2, *x), jnp.float32) for x in input_shape]
- jax_model = jax_model.init(
- init_rngs, train=False, *fake_input_batch)["params"]
+ jax_model = jax_model.init(init_rngs, train=False, *fake_input_batch)[
+ 'params'
+ ]
# Run model once to initialize lazy layers
wave = torch.randn(2, 320000)
@@ -136,23 +153,26 @@ def get_models(workload):
input_shape = (16, 256)
target_shape = (16, 256)
jax_model = JaxTransformer(TransformerConfig).init(
- init_rngs,
- jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32))['params']
+ init_rngs,
+ jnp.ones(input_shape, jnp.float32),
+ jnp.ones(target_shape, jnp.float32),
+ )['params']
# Init PyTorch model.
pytorch_model = PyTorchTransformer()
elif workload == 'ogbg':
# Init Jax model.
fake_batch = jraph.GraphsTuple(
- n_node=jnp.asarray([1]),
- n_edge=jnp.asarray([1]),
- nodes=jnp.ones((1, 9)),
- edges=jnp.ones((1, 3)),
- globals=jnp.zeros((1, 128)),
- senders=jnp.asarray([0]),
- receivers=jnp.asarray([0]))
+ n_node=jnp.asarray([1]),
+ n_edge=jnp.asarray([1]),
+ nodes=jnp.ones((1, 9)),
+ edges=jnp.ones((1, 3)),
+ globals=jnp.zeros((1, 128)),
+ senders=jnp.asarray([0]),
+ receivers=jnp.asarray([0]),
+ )
jax_model = JaxGNN(num_outputs=128).init(
- init_rngs, fake_batch, train=False)['params']
+ init_rngs, fake_batch, train=False
+ )['params']
# Init PyTorch model.
pytorch_model = PyTorchGNN(num_outputs=128)
else:
diff --git a/tests/test_param_shapes.py b/tests/test_param_shapes.py
index df4c798d8..1badd39ed 100644
--- a/tests/test_param_shapes.py
+++ b/tests/test_param_shapes.py
@@ -7,40 +7,80 @@
# isort: skip_file
# pylint:disable=line-too-long
-from algoperf.workloads.cifar.cifar_jax.workload import CifarWorkload as JaxCifarWorkload
-from algoperf.workloads.cifar.cifar_pytorch.workload import CifarWorkload as PyTorchCifarWorkload
-from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload
-from algoperf.workloads.fastmri.fastmri_jax.workload import FastMRIWorkload as JaxFastMRIWorkload
-from algoperf.workloads.fastmri.fastmri_pytorch.workload import FastMRIWorkload as PyTorchFastMRIWorkload
-from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ImagenetResNetWorkload as JaxImagenetResNetWorkload
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ImagenetResNetWorkload as PyTorchImagenetResNetWorkload
-from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ImagenetVitWorkload as JaxImagenetViTWorkload
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ImagenetVitWorkload as PyTorchImagenetViTWorkload
-from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload
-from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload
-from algoperf.workloads.mnist.mnist_jax.workload import MnistWorkload as JaxMnistWorkload
-from algoperf.workloads.mnist.mnist_pytorch.workload import MnistWorkload as PyTorchMnistWorkload
-from algoperf.workloads.ogbg.ogbg_jax.workload import OgbgWorkload as JaxOgbgWorkload
-from algoperf.workloads.ogbg.ogbg_pytorch.workload import OgbgWorkload as PyTorchOgbgWorkload
-from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWmtWorkload
-from algoperf.workloads.wmt.wmt_pytorch.workload import WmtWorkload as PyTorchWmtWorkload
+from algoperf.workloads.cifar.cifar_jax.workload import (
+ CifarWorkload as JaxCifarWorkload,
+)
+from algoperf.workloads.cifar.cifar_pytorch.workload import (
+ CifarWorkload as PyTorchCifarWorkload,
+)
+from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import (
+ Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload,
+)
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import (
+ Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload,
+)
+from algoperf.workloads.fastmri.fastmri_jax.workload import (
+ FastMRIWorkload as JaxFastMRIWorkload,
+)
+from algoperf.workloads.fastmri.fastmri_pytorch.workload import (
+ FastMRIWorkload as PyTorchFastMRIWorkload,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import (
+ ImagenetResNetWorkload as JaxImagenetResNetWorkload,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import (
+ ImagenetResNetWorkload as PyTorchImagenetResNetWorkload,
+)
+from algoperf.workloads.imagenet_vit.imagenet_jax.workload import (
+ ImagenetVitWorkload as JaxImagenetViTWorkload,
+)
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import (
+ ImagenetVitWorkload as PyTorchImagenetViTWorkload,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import (
+ LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import (
+ LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import (
+ LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import (
+ LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload,
+)
+from algoperf.workloads.mnist.mnist_jax.workload import (
+ MnistWorkload as JaxMnistWorkload,
+)
+from algoperf.workloads.mnist.mnist_pytorch.workload import (
+ MnistWorkload as PyTorchMnistWorkload,
+)
+from algoperf.workloads.ogbg.ogbg_jax.workload import (
+ OgbgWorkload as JaxOgbgWorkload,
+)
+from algoperf.workloads.ogbg.ogbg_pytorch.workload import (
+ OgbgWorkload as PyTorchOgbgWorkload,
+)
+from algoperf.workloads.wmt.wmt_jax.workload import (
+ WmtWorkload as JaxWmtWorkload,
+)
+from algoperf.workloads.wmt.wmt_pytorch.workload import (
+ WmtWorkload as PyTorchWmtWorkload,
+)
# pylint:enable=line-too-long
WORKLOADS = [
- 'cifar',
- 'criteo1tb',
- 'fastmri',
- 'imagenet_resnet',
- 'imagenet_vit',
- # TODO: make tests work for these.
- # 'librispeech_conformer',
- # 'librispeech_deepspeech',
- 'mnist',
- 'ogbg',
- 'wmt',
+ 'cifar',
+ 'criteo1tb',
+ 'fastmri',
+ 'imagenet_resnet',
+ 'imagenet_vit',
+ # TODO: make tests work for these.
+ # 'librispeech_conformer',
+ # 'librispeech_deepspeech',
+ 'mnist',
+ 'ogbg',
+ 'wmt',
]
@@ -56,9 +96,11 @@ def test_param_shapes(workload):
if isinstance(jax_workload_param_shapes, dict):
jax_workload_param_shapes = FrozenDict(jax_workload_param_shapes)
jax_param_shapes = jax.tree_util.tree_leaves(
- jax_workload_param_shapes.unfreeze())
+ jax_workload_param_shapes.unfreeze()
+ )
pytorch_param_shapes = jax.tree_util.tree_leaves(
- pytorch_workload.param_shapes)
+ pytorch_workload.param_shapes
+ )
if workload == 'wmt':
# The PyTorch transformer for WMT is implemented with fused linear layers
# for the projection of QKV inside of the MultiheadAttention module.
@@ -74,8 +116,9 @@ def test_param_shapes(workload):
# Check if total number of params deduced from shapes match.
num_jax_params = 0
num_pytorch_params = 0
- for jax_shape, pytorch_shape in zip_longest(jax_param_shapes,
- pytorch_param_shapes):
+ for jax_shape, pytorch_shape in zip_longest(
+ jax_param_shapes, pytorch_param_shapes
+ ):
if jax_shape is not None:
num_jax_params += np.prod(jax_shape.shape_tuple)
if pytorch_shape is not None:
diff --git a/tests/test_param_types.py b/tests/test_param_types.py
index d3722ae86..1583342ff 100644
--- a/tests/test_param_types.py
+++ b/tests/test_param_types.py
@@ -6,39 +6,79 @@
# isort: skip_file
# pylint:disable=line-too-long
-from algoperf.workloads.cifar.cifar_jax.workload import CifarWorkload as JaxCifarWorkload
-from algoperf.workloads.cifar.cifar_pytorch.workload import CifarWorkload as PyTorchCifarWorkload
-from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload
-from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload
-from algoperf.workloads.fastmri.fastmri_jax.workload import FastMRIWorkload as JaxFastMRIWorkload
-from algoperf.workloads.fastmri.fastmri_pytorch.workload import FastMRIWorkload as PyTorchFastMRIWorkload
-from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import ImagenetResNetWorkload as JaxImagenetResNetWorkload
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import ImagenetResNetWorkload as PyTorchImagenetResNetWorkload
-from algoperf.workloads.imagenet_vit.imagenet_jax.workload import ImagenetVitWorkload as JaxImagenetViTWorkload
-from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import ImagenetVitWorkload as PyTorchImagenetViTWorkload
-from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload
-from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload
-from algoperf.workloads.mnist.mnist_jax.workload import MnistWorkload as JaxMnistWorkload
-from algoperf.workloads.mnist.mnist_pytorch.workload import MnistWorkload as PyTorchMnistWorkload
-from algoperf.workloads.ogbg.ogbg_jax.workload import OgbgWorkload as JaxOgbgWorkload
-from algoperf.workloads.ogbg.ogbg_pytorch.workload import OgbgWorkload as PyTorchOgbgWorkload
-from algoperf.workloads.wmt.wmt_jax.workload import WmtWorkload as JaxWmtWorkload
-from algoperf.workloads.wmt.wmt_pytorch.workload import WmtWorkload as PyTorchWmtWorkload
+from algoperf.workloads.cifar.cifar_jax.workload import (
+ CifarWorkload as JaxCifarWorkload,
+)
+from algoperf.workloads.cifar.cifar_pytorch.workload import (
+ CifarWorkload as PyTorchCifarWorkload,
+)
+from algoperf.workloads.criteo1tb.criteo1tb_jax.workload import (
+ Criteo1TbDlrmSmallWorkload as JaxCriteoWorkload,
+)
+from algoperf.workloads.criteo1tb.criteo1tb_pytorch.workload import (
+ Criteo1TbDlrmSmallWorkload as PyTorchCriteoWorkload,
+)
+from algoperf.workloads.fastmri.fastmri_jax.workload import (
+ FastMRIWorkload as JaxFastMRIWorkload,
+)
+from algoperf.workloads.fastmri.fastmri_pytorch.workload import (
+ FastMRIWorkload as PyTorchFastMRIWorkload,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import (
+ ImagenetResNetWorkload as JaxImagenetResNetWorkload,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import (
+ ImagenetResNetWorkload as PyTorchImagenetResNetWorkload,
+)
+from algoperf.workloads.imagenet_vit.imagenet_jax.workload import (
+ ImagenetVitWorkload as JaxImagenetViTWorkload,
+)
+from algoperf.workloads.imagenet_vit.imagenet_pytorch.workload import (
+ ImagenetVitWorkload as PyTorchImagenetViTWorkload,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import (
+ LibriSpeechConformerWorkload as JaxLibriSpeechConformerWorkload,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import (
+ LibriSpeechConformerWorkload as PytorchLibriSpeechConformerWorkload,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_jax.workload import (
+ LibriSpeechDeepSpeechWorkload as JaxLibriSpeechDeepSpeechWorkload,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.workload import (
+ LibriSpeechDeepSpeechWorkload as PytorchLibriSpeechDeepSpeechWorkload,
+)
+from algoperf.workloads.mnist.mnist_jax.workload import (
+ MnistWorkload as JaxMnistWorkload,
+)
+from algoperf.workloads.mnist.mnist_pytorch.workload import (
+ MnistWorkload as PyTorchMnistWorkload,
+)
+from algoperf.workloads.ogbg.ogbg_jax.workload import (
+ OgbgWorkload as JaxOgbgWorkload,
+)
+from algoperf.workloads.ogbg.ogbg_pytorch.workload import (
+ OgbgWorkload as PyTorchOgbgWorkload,
+)
+from algoperf.workloads.wmt.wmt_jax.workload import (
+ WmtWorkload as JaxWmtWorkload,
+)
+from algoperf.workloads.wmt.wmt_pytorch.workload import (
+ WmtWorkload as PyTorchWmtWorkload,
+)
# pylint:enable=line-too-long
WORKLOADS = [
- 'cifar',
- 'criteo1tb',
- 'fastmri',
- 'imagenet_resnet',
- 'imagenet_vit',
- 'librispeech_conformer',
- 'librispeech_deepspeech',
- 'mnist',
- 'ogbg',
- 'wmt',
+ 'cifar',
+ 'criteo1tb',
+ 'fastmri',
+ 'imagenet_resnet',
+ 'imagenet_vit',
+ 'librispeech_conformer',
+ 'librispeech_deepspeech',
+ 'mnist',
+ 'ogbg',
+ 'wmt',
]
@@ -66,40 +106,32 @@ def _check_attention_qkv_match(jax_param_types_dict, pytorch_param_types_dict):
# Sometimes one framework will implement QKV as a single parameter, so we need
# to make sure there are the same number of QKV params as Q, K, V.
num_qkv = {
- 'jax':
- jax_param_types_dict.get(spec.ParameterType.ATTENTION_QKV, 0),
- 'pytorch':
- pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_QKV, 0),
+ 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_QKV, 0),
+ 'pytorch': pytorch_param_types_dict.get(
+ spec.ParameterType.ATTENTION_QKV, 0
+ ),
}
num_kv = {
- 'jax':
- jax_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0),
- 'pytorch':
- pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0),
+ 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0),
+ 'pytorch': pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_KV, 0),
}
num_q = {
- 'jax':
- jax_param_types_dict.get(spec.ParameterType.ATTENTION_Q, 0),
- 'pytorch':
- pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_Q, 0),
+ 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_Q, 0),
+ 'pytorch': pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_Q, 0),
}
num_k = {
- 'jax':
- jax_param_types_dict.get(spec.ParameterType.ATTENTION_K, 0),
- 'pytorch':
- pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_K, 0),
+ 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_K, 0),
+ 'pytorch': pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_K, 0),
}
num_v = {
- 'jax':
- jax_param_types_dict.get(spec.ParameterType.ATTENTION_V, 0),
- 'pytorch':
- pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_V, 0),
+ 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_V, 0),
+ 'pytorch': pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_V, 0),
}
num_bias = {
- 'jax':
- jax_param_types_dict.get(spec.ParameterType.ATTENTION_BIAS, 0),
- 'pytorch':
- pytorch_param_types_dict.get(spec.ParameterType.ATTENTION_BIAS, 0),
+ 'jax': jax_param_types_dict.get(spec.ParameterType.ATTENTION_BIAS, 0),
+ 'pytorch': pytorch_param_types_dict.get(
+ spec.ParameterType.ATTENTION_BIAS, 0
+ ),
}
qkv_match = num_qkv['jax'] == num_qkv['pytorch']
kv_match = num_kv['jax'] == num_kv['pytorch']
@@ -108,24 +140,33 @@ def _check_attention_qkv_match(jax_param_types_dict, pytorch_param_types_dict):
v_match = num_v['jax'] == num_v['pytorch']
bias_match = num_bias['jax'] == num_bias['pytorch']
qkv_match = (
- qkv_match and kv_match and q_match and k_match and v_match and bias_match)
+ qkv_match and kv_match and q_match and k_match and v_match and bias_match
+ )
# We subtract 2 * num_qkv from the number of biases because there are 2
# missing for each of q, k, v.
- jax_qkv_match = (
- num_q['pytorch'] == num_k['pytorch'] == num_v['pytorch'] == num_qkv['jax']
- and (num_qkv['jax'] != 0 and
- (num_bias['pytorch'] - 2 * num_qkv['jax']) == num_bias['jax']))
- pytorch_qkv_match = (
- num_q['jax'] == num_k['jax'] == num_v['jax'] == num_qkv['pytorch'] and
- (num_qkv['pytorch'] != 0 and
- (num_bias['jax'] - 2 * num_qkv['pytorch']) == num_bias['pytorch']))
+ jax_qkv_match = num_q['pytorch'] == num_k['pytorch'] == num_v[
+ 'pytorch'
+ ] == num_qkv['jax'] and (
+ num_qkv['jax'] != 0
+ and (num_bias['pytorch'] - 2 * num_qkv['jax']) == num_bias['jax']
+ )
+ pytorch_qkv_match = num_q['jax'] == num_k['jax'] == num_v['jax'] == num_qkv[
+ 'pytorch'
+ ] and (
+ num_qkv['pytorch'] != 0
+ and (num_bias['jax'] - 2 * num_qkv['pytorch']) == num_bias['pytorch']
+ )
pytorch_kv_match = (
- num_q['jax'] == num_k['jax'] == num_v['jax'] ==
- num_qkv['pytorch'] + num_kv['pytorch'] and
- num_q['pytorch'] == num_kv['pytorch'])
+ num_q['jax']
+ == num_k['jax']
+ == num_v['jax']
+ == num_qkv['pytorch'] + num_kv['pytorch']
+ and num_q['pytorch'] == num_kv['pytorch']
+ )
qkv_match = (
- qkv_match or jax_qkv_match or pytorch_qkv_match or pytorch_kv_match)
+ qkv_match or jax_qkv_match or pytorch_qkv_match or pytorch_kv_match
+ )
return qkv_match
@@ -137,7 +178,8 @@ def test_param_types(workload_name):
# Compare number of parameter tensors of both models.
jax_param_types = jax.tree_util.tree_leaves(jax_workload.model_params_types)
pytorch_param_types = jax.tree_util.tree_leaves(
- pytorch_workload.model_params_types)
+ pytorch_workload.model_params_types
+ )
jax_param_types_dict = count_param_types(jax_param_types)
pytorch_param_types_dict = count_param_types(pytorch_param_types)
@@ -161,30 +203,33 @@ def test_param_types(workload_name):
# Check if total number of each type match.
attention_keys = {
- spec.ParameterType.ATTENTION_QKV,
- spec.ParameterType.ATTENTION_KV,
- spec.ParameterType.ATTENTION_Q,
- spec.ParameterType.ATTENTION_K,
- spec.ParameterType.ATTENTION_V,
- spec.ParameterType.ATTENTION_BIAS,
+ spec.ParameterType.ATTENTION_QKV,
+ spec.ParameterType.ATTENTION_KV,
+ spec.ParameterType.ATTENTION_Q,
+ spec.ParameterType.ATTENTION_K,
+ spec.ParameterType.ATTENTION_V,
+ spec.ParameterType.ATTENTION_BIAS,
}
non_attention_keys = set(jax_param_types_dict.keys()).union(
- set(pytorch_param_types_dict.keys()))
+ set(pytorch_param_types_dict.keys())
+ )
non_attention_keys -= attention_keys
mismatches = ''
- mismatches += _count_mismatches(jax_param_types_dict,
- pytorch_param_types_dict,
- non_attention_keys)
- qkv_match = _check_attention_qkv_match(jax_param_types_dict,
- pytorch_param_types_dict)
+ mismatches += _count_mismatches(
+ jax_param_types_dict, pytorch_param_types_dict, non_attention_keys
+ )
+ qkv_match = _check_attention_qkv_match(
+ jax_param_types_dict, pytorch_param_types_dict
+ )
if not qkv_match:
- mismatches += _count_mismatches(jax_param_types_dict,
- pytorch_param_types_dict,
- attention_keys)
+ mismatches += _count_mismatches(
+ jax_param_types_dict, pytorch_param_types_dict, attention_keys
+ )
if mismatches:
raise ValueError(
- f'On workload {workload_name}, count mismatch: {mismatches}')
+ f'On workload {workload_name}, count mismatch: {mismatches}'
+ )
def get_workload(workload_name):
diff --git a/tests/test_ssim.py b/tests/test_ssim.py
index 920556964..7d730c251 100644
--- a/tests/test_ssim.py
+++ b/tests/test_ssim.py
@@ -10,13 +10,14 @@
import torch
from algoperf.pytorch_utils import pytorch_setup
-from algoperf.workloads.fastmri.fastmri_jax.ssim import \
- _uniform_filter as _jax_uniform_filter
+from algoperf.workloads.fastmri.fastmri_jax.ssim import (
+ _uniform_filter as _jax_uniform_filter,
+)
from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim as jax_ssim
-from algoperf.workloads.fastmri.fastmri_pytorch.ssim import \
- _uniform_filter as _pytorch_uniform_filter
-from algoperf.workloads.fastmri.fastmri_pytorch.ssim import \
- ssim as pytorch_ssim
+from algoperf.workloads.fastmri.fastmri_pytorch.ssim import (
+ _uniform_filter as _pytorch_uniform_filter,
+)
+from algoperf.workloads.fastmri.fastmri_pytorch.ssim import ssim as pytorch_ssim
# Make sure no GPU memory is preallocated to Jax.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
@@ -31,7 +32,7 @@ def _create_fake_im(height: int, width: int) -> Tuple[jnp.array, torch.Tensor]:
def _create_fake_batch(
- batch_size: int, height: int, width: int
+ batch_size: int, height: int, width: int
) -> Tuple[Tuple[jnp.array, jnp.array], Tuple[torch.Tensor, torch.Tensor]]:
logits = np.random.randn(batch_size, height, width)
targets = np.random.randn(batch_size, height, width)
@@ -47,9 +48,9 @@ class SSIMTest(parameterized.TestCase):
and PyTorch."""
@parameterized.named_parameters(
- dict(testcase_name='fastmri_im', height=320, width=320),
- dict(testcase_name='uneven_even_im', height=31, width=16),
- dict(testcase_name='even_uneven_im', height=42, width=53),
+ dict(testcase_name='fastmri_im', height=320, width=320),
+ dict(testcase_name='uneven_even_im', height=31, width=16),
+ dict(testcase_name='even_uneven_im', height=42, width=53),
)
def test_uniform_filter(self, height: int, width: int) -> None:
jax_im, pytorch_im = _create_fake_im(height, width)
@@ -58,12 +59,9 @@ def test_uniform_filter(self, height: int, width: int) -> None:
assert np.allclose(jax_result, torch_result, atol=1e-6)
@parameterized.named_parameters(
- dict(
- testcase_name='fastmri_batch', batch_size=256, height=320, width=320),
- dict(
- testcase_name='uneven_even_batch', batch_size=8, height=31, width=16),
- dict(
- testcase_name='even_uneven_batch', batch_size=8, height=42, width=53),
+ dict(testcase_name='fastmri_batch', batch_size=256, height=320, width=320),
+ dict(testcase_name='uneven_even_batch', batch_size=8, height=31, width=16),
+ dict(testcase_name='even_uneven_batch', batch_size=8, height=42, width=53),
)
def test_ssim(self, batch_size: int, height: int, width: int) -> None:
jax_inputs, pytorch_inputs = _create_fake_batch(batch_size, height, width)
@@ -71,9 +69,8 @@ def test_ssim(self, batch_size: int, height: int, width: int) -> None:
pytorch_ssim_result = pytorch_ssim(*pytorch_inputs)
self.assertEqual(jax_ssim_result.shape, pytorch_ssim_result.shape)
assert np.allclose(
- jax_ssim_result.sum().item(),
- pytorch_ssim_result.sum().item(),
- atol=1e-6)
+ jax_ssim_result.sum().item(), pytorch_ssim_result.sum().item(), atol=1e-6
+ )
if __name__ == '__main__':
diff --git a/tests/test_traindiffs.py b/tests/test_traindiffs.py
index cea589202..b1982a3bf 100644
--- a/tests/test_traindiffs.py
+++ b/tests/test_traindiffs.py
@@ -3,6 +3,7 @@
Run it as:
python3 test_traindiffs.py
"""
+
import pickle
import subprocess
from subprocess import DEVNULL
@@ -17,14 +18,14 @@
FLAGS = flags.FLAGS
WORKLOADS = [
- 'imagenet_resnet',
- 'imagenet_vit',
- 'wmt',
- 'librispeech_conformer',
- 'librispeech_deepspeech',
- 'fastmri',
- 'ogbg',
- 'criteo1tb'
+ 'imagenet_resnet',
+ 'imagenet_vit',
+ 'wmt',
+ 'librispeech_conformer',
+ 'librispeech_deepspeech',
+ 'fastmri',
+ 'ogbg',
+ 'criteo1tb',
]
GLOBAL_BATCH_SIZE = 16
NUM_TRAIN_STEPS = 10
@@ -35,7 +36,6 @@
class ModelDiffTest(parameterized.TestCase):
-
@parameterized.named_parameters(*named_parameters)
def test_workload(self, workload):
# pylint: disable=line-too-long, unnecessary-lambda-assignment
@@ -50,24 +50,26 @@ def test_workload(self, workload):
pytorch_logs_path = '/tmp/pyt_log.pkl'
try:
run(
- f'XLA_PYTHON_CLIENT_ALLOCATOR=platform python -m tests.reference_algorithm_tests --workload={workload} --framework=jax --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={jax_logs_path}'
- f' --submission_path=tests/modeldiffs/vanilla_sgd_jax.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}',
- shell=True,
- stdout=DEVNULL,
- stderr=STDOUT,
- check=True)
+ f'XLA_PYTHON_CLIENT_ALLOCATOR=platform python -m tests.reference_algorithm_tests --workload={workload} --framework=jax --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={jax_logs_path}'
+ f' --submission_path=tests/modeldiffs/vanilla_sgd_jax.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}',
+ shell=True,
+ stdout=DEVNULL,
+ stderr=STDOUT,
+ check=True,
+ )
except subprocess.CalledProcessError as e:
- print("Error:", e)
+ print('Error:', e)
try:
run(
- f'XLA_PYTHON_CLIENT_ALLOCATOR=platform torchrun --standalone --nnodes 1 --nproc_per_node 8 -m tests.reference_algorithm_tests --workload={workload} --framework=pytorch --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={pytorch_logs_path}'
- f' --submission_path=tests/modeldiffs/vanilla_sgd_pytorch.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}',
- shell=True,
- stdout=DEVNULL,
- stderr=STDOUT,
- check=True)
+ f'XLA_PYTHON_CLIENT_ALLOCATOR=platform torchrun --standalone --nnodes 1 --nproc_per_node 8 -m tests.reference_algorithm_tests --workload={workload} --framework=pytorch --global_batch_size={GLOBAL_BATCH_SIZE} --log_file={pytorch_logs_path}'
+ f' --submission_path=tests/modeldiffs/vanilla_sgd_pytorch.py --identical=True --tuning_search_space=None --num_train_steps={NUM_TRAIN_STEPS}',
+ shell=True,
+ stdout=DEVNULL,
+ stderr=STDOUT,
+ check=True,
+ )
except subprocess.CalledProcessError as e:
- print("Error:", e)
+ print('Error:', e)
with open(jax_logs_path, 'rb') as f:
jax_results = pickle.load(f)
with open(pytorch_logs_path, 'rb') as f:
@@ -75,17 +77,20 @@ def test_workload(self, workload):
# PRINT RESULTS
eval_metric_key = next(
- iter(
- filter(lambda k: 'train' in k and 'loss' in k,
- jax_results['eval_results'][0])))
+ iter(
+ filter(
+ lambda k: 'train' in k and 'loss' in k, jax_results['eval_results'][0]
+ )
+ )
+ )
header = [
- 'Iter',
- 'Eval (jax)',
- 'Eval (torch)',
- 'Grad Norm (jax)',
- 'Grad Norm (torch)',
- 'Train Loss (jax)',
- 'Train Loss (torch)',
+ 'Iter',
+ 'Eval (jax)',
+ 'Eval (torch)',
+ 'Grad Norm (jax)',
+ 'Grad Norm (torch)',
+ 'Train Loss (jax)',
+ 'Train Loss (torch)',
]
fmt = lambda l: '|' + '|'.join(map(lambda x: f'{x:^20s}', l)) + '|'
header = fmt(header)
@@ -97,33 +102,41 @@ def test_workload(self, workload):
for i in range(NUM_TRAIN_STEPS):
rtol = 1
- row = map(lambda x: str(round(x, 5)),
- [
- jax_results['eval_results'][i][eval_metric_key],
- pytorch_results['eval_results'][i][eval_metric_key],
- jax_results['scalars'][i]['grad_norm'],
- pytorch_results['scalars'][i]['grad_norm'],
- jax_results['scalars'][i]['loss'],
- pytorch_results['scalars'][i]['loss'],
- ])
+ row = map(
+ lambda x: str(round(x, 5)),
+ [
+ jax_results['eval_results'][i][eval_metric_key],
+ pytorch_results['eval_results'][i][eval_metric_key],
+ jax_results['scalars'][i]['grad_norm'],
+ pytorch_results['scalars'][i]['grad_norm'],
+ jax_results['scalars'][i]['loss'],
+ pytorch_results['scalars'][i]['loss'],
+ ],
+ )
print(fmt([f'{i}', *row]))
print('=' * len(header))
self.assertTrue( # eval_results
- allclose(
- jax_results['eval_results'][i][eval_metric_key],
- pytorch_results['eval_results'][i][eval_metric_key],
- rtol=rtol))
+ allclose(
+ jax_results['eval_results'][i][eval_metric_key],
+ pytorch_results['eval_results'][i][eval_metric_key],
+ rtol=rtol,
+ )
+ )
self.assertTrue( # grad_norms
- allclose(
- jax_results['scalars'][i]['grad_norm'],
- pytorch_results['scalars'][i]['grad_norm'],
- rtol=rtol))
+ allclose(
+ jax_results['scalars'][i]['grad_norm'],
+ pytorch_results['scalars'][i]['grad_norm'],
+ rtol=rtol,
+ )
+ )
self.assertTrue( # loss
- allclose(
- jax_results['scalars'][i]['loss'],
- pytorch_results['scalars'][i]['loss'],
- rtol=rtol))
+ allclose(
+ jax_results['scalars'][i]['loss'],
+ pytorch_results['scalars'][i]['loss'],
+ rtol=rtol,
+ )
+ )
if __name__ == '__main__':
diff --git a/tests/test_version.py b/tests/test_version.py
index d1bfbd18f..69384953a 100644
--- a/tests/test_version.py
+++ b/tests/test_version.py
@@ -6,10 +6,10 @@
def test_version_attribute():
"""Check whether __version__ exists and is a valid string."""
- assert hasattr(algoperf, "__version__")
+ assert hasattr(algoperf, '__version__')
version = algoperf.__version__
assert isinstance(version, str)
- version_elements = version.split(".")
+ version_elements = version.split('.')
print(version_elements)
# Only check the first two elements, i.e. major, minor
# (patch is not checked as it is not required).
diff --git a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py
index d44234927..3d06c9839 100644
--- a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py
+++ b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py
@@ -5,8 +5,9 @@
import jax.numpy as jnp
from algoperf import spec
-from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \
- ImagenetResNetWorkload
+from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import (
+ ImagenetResNetWorkload,
+)
def _pytree_total_diff(pytree_a, pytree_b):
@@ -32,42 +33,48 @@ def test_forward_pass(self):
# this function because we call it with a different combination of those two
# args each time. Can't call with kwargs.
pmapped_model_fn = jax.pmap(
- workload.model_fn,
- axis_name='batch',
- in_axes=(0, 0, 0, None, None, None),
- static_broadcasted_argnums=(3, 5))
+ workload.model_fn,
+ axis_name='batch',
+ in_axes=(0, 0, 0, None, None, None),
+ static_broadcasted_argnums=(3, 5),
+ )
logits, updated_batch_stats = pmapped_model_fn(
- model_params,
- {'inputs': first_input_batch},
- batch_stats,
- spec.ForwardPassMode.TRAIN,
- rng,
- True)
+ model_params,
+ {'inputs': first_input_batch},
+ batch_stats,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ True,
+ )
self.assertEqual(logits.shape, expected_logits_shape)
# Test that batch stats are updated.
self.assertNotEqual(
- _pytree_total_diff(batch_stats, updated_batch_stats), 0.0)
+ _pytree_total_diff(batch_stats, updated_batch_stats), 0.0
+ )
second_input_batch = jax.random.normal(data_rngs[1], shape=input_shape)
# Test that batch stats are not updated when we say so.
_, same_batch_stats = pmapped_model_fn(
- model_params,
- {'inputs': second_input_batch},
- updated_batch_stats,
- spec.ForwardPassMode.TRAIN,
- rng,
- False)
+ model_params,
+ {'inputs': second_input_batch},
+ updated_batch_stats,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ False,
+ )
self.assertEqual(
- _pytree_total_diff(same_batch_stats, updated_batch_stats), 0.0)
+ _pytree_total_diff(same_batch_stats, updated_batch_stats), 0.0
+ )
# Test eval model.
logits, _ = pmapped_model_fn(
- model_params,
- {'inputs': second_input_batch},
- batch_stats,
- spec.ForwardPassMode.EVAL,
- rng,
- False)
+ model_params,
+ {'inputs': second_input_batch},
+ batch_stats,
+ spec.ForwardPassMode.EVAL,
+ rng,
+ False,
+ )
self.assertEqual(logits.shape, expected_logits_shape)
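The test above pmaps model_fn with positional in_axes and marks the Python-level mode/update_batch_norm flags as static_broadcasted_argnums, because those flags change between calls and pmap cannot be called with kwargs. A minimal standalone sketch of that pattern, using a hypothetical toy apply function instead of the actual workload (assumes any JAX backend with at least one local device):

import jax
import jax.numpy as jnp

def apply(params, batch, train):
  # 'train' is a plain Python bool, so it must be a static pmap argument.
  scale = 2.0 if train else 1.0
  return scale * (batch @ params)

n = jax.local_device_count()
params = jnp.ones((n, 4, 2))  # leading axis = device axis
batch = jnp.ones((n, 8, 4))
pmapped_apply = jax.pmap(
  apply,
  axis_name='batch',
  in_axes=(0, 0, None),
  static_broadcasted_argnums=(2,),
)
out = pmapped_apply(params, batch, True)  # positional args only; out.shape == (n, 8, 2)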
From f02671100980e742464c273a0825df9c4d99f264 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 10:20:18 +0200
Subject: [PATCH 105/123] Format prize_qualification_baselines/
---
.../external_tuning/jax_nadamw_full_budget.py | 269 ++++++++--------
.../jax_nadamw_target_setting.py | 269 ++++++++--------
.../pytorch_nadamw_full_budget.py | 274 +++++++++--------
.../pytorch_nadamw_target_setting.py | 274 +++++++++--------
.../self_tuning/jax_nadamw_full_budget.py | 281 +++++++++--------
.../self_tuning/jax_nadamw_target_setting.py | 281 +++++++++--------
.../self_tuning/pytorch_nadamw_full_budget.py | 286 ++++++++++--------
.../pytorch_nadamw_target_setting.py | 286 ++++++++++--------
8 files changed, 1224 insertions(+), 996 deletions(-)
diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
index c451a18ac..aa1a08f69 100644
--- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
+++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
@@ -4,15 +4,17 @@
# isort: off
# We have to turn off isort here to resolve a conflict between isort and yapf.
-from typing import (Any,
- Callable,
- Dict,
- Iterator,
- List,
- NamedTuple,
- Optional,
- Tuple,
- Union)
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ NamedTuple,
+ Optional,
+ Tuple,
+ Union,
+)
# isort: on
import chex
@@ -30,15 +32,14 @@
# Forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py
def nadamw(
- learning_rate: Union[float, optax.Schedule],
- b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- weight_decay: float = 0.0,
- weight_decay_mask: Optional[Union[Any, Callable[[optax.Params],
- Any]]] = None,
+ learning_rate: Union[float, optax.Schedule],
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ weight_decay: float = 0.0,
+ weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None,
) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
@@ -73,19 +74,22 @@ def nadamw(
An (init_fn, update_fn) tuple.
"""
return optax.chain(
- scale_by_nadam(b1, b2, eps, eps_root, debias),
- optax.add_decayed_weights(weight_decay, weight_decay_mask),
- scale_by_learning_rate(learning_rate))
+ scale_by_nadam(b1, b2, eps, eps_root, debias),
+ optax.add_decayed_weights(weight_decay, weight_decay_mask),
+ scale_by_learning_rate(learning_rate),
+ )
# All functions below are forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py
-def scale_by_nadam(b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- power: float = 0.5) -> optax.GradientTransformation:
+def scale_by_nadam(
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ power: float = 0.5,
+) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
References:
@@ -124,7 +128,8 @@ def update_fn(updates, state, params=None):
mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
nu_hat = nu if not debias else _bias_correction(nu, b2, count)
updates = jax.tree.map(
- lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
+ lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat
+ )
return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
return optax.GradientTransformation(init_fn, update_fn)
@@ -132,6 +137,7 @@ def update_fn(updates, state, params=None):
class ScaleByAdamState(NamedTuple):
"""State for the NAdam algorithm."""
+
count: chex.Array # shape=(), dtype=jnp.int32.
mu: optax.Updates
nu: optax.Updates
@@ -140,7 +146,8 @@ class ScaleByAdamState(NamedTuple):
def _update_moment(updates, moments, decay, order):
"""Compute the exponential moving average of the `order-th` moment."""
return jax.tree.map(
- lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
+ lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments
+ )
def _bias_correction(moment, decay, count):
@@ -156,11 +163,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True):
return optax.scale(m * learning_rate)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -170,101 +179,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+ init_value=hyperparameters.learning_rate, decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters)
opt_init_fn, opt_update_fn = nadamw(
- learning_rate=lr_schedule_fn,
- b1=1.0 - hyperparameters.one_minus_beta1,
- b2=hyperparameters.beta2,
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay)
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ learning_rate=lr_schedule_fn,
+ b1=1.0 - hyperparameters.one_minus_beta1,
+ b2=hyperparameters.beta2,
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ )
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -281,37 +304,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -351,14 +380,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py
index b8ac10f33..f1d7d62e0 100644
--- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py
+++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py
@@ -4,15 +4,17 @@
# isort: off
# We have to turn off isort here to resolve a conflict between isort and yapf.
-from typing import (Any,
- Callable,
- Dict,
- Iterator,
- List,
- NamedTuple,
- Optional,
- Tuple,
- Union)
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ NamedTuple,
+ Optional,
+ Tuple,
+ Union,
+)
# isort: on
import chex
@@ -30,15 +32,14 @@
# Forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py
def nadamw(
- learning_rate: Union[float, optax.Schedule],
- b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- weight_decay: float = 0.0,
- weight_decay_mask: Optional[Union[Any, Callable[[optax.Params],
- Any]]] = None,
+ learning_rate: Union[float, optax.Schedule],
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ weight_decay: float = 0.0,
+ weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None,
) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
@@ -73,19 +74,22 @@ def nadamw(
An (init_fn, update_fn) tuple.
"""
return optax.chain(
- scale_by_nadam(b1, b2, eps, eps_root, debias),
- optax.add_decayed_weights(weight_decay, weight_decay_mask),
- scale_by_learning_rate(learning_rate))
+ scale_by_nadam(b1, b2, eps, eps_root, debias),
+ optax.add_decayed_weights(weight_decay, weight_decay_mask),
+ scale_by_learning_rate(learning_rate),
+ )
# All functions below are forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py
-def scale_by_nadam(b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- power: float = 0.5) -> optax.GradientTransformation:
+def scale_by_nadam(
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ power: float = 0.5,
+) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
References:
@@ -124,7 +128,8 @@ def update_fn(updates, state, params=None):
mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
nu_hat = nu if not debias else _bias_correction(nu, b2, count)
updates = jax.tree.map(
- lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
+ lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat
+ )
return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
return optax.GradientTransformation(init_fn, update_fn)
@@ -132,6 +137,7 @@ def update_fn(updates, state, params=None):
class ScaleByAdamState(NamedTuple):
"""State for the NAdam algorithm."""
+
count: chex.Array # shape=(), dtype=jnp.int32.
mu: optax.Updates
nu: optax.Updates
@@ -140,7 +146,8 @@ class ScaleByAdamState(NamedTuple):
def _update_moment(updates, moments, decay, order):
"""Compute the exponential moving average of the `order-th` moment."""
return jax.tree.map(
- lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
+ lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments
+ )
def _bias_correction(moment, decay, count):
@@ -156,11 +163,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True):
return optax.scale(m * learning_rate)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -170,101 +179,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters.learning_rate,
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters.learning_rate,
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters.learning_rate, decay_steps=cosine_steps)
+ init_value=hyperparameters.learning_rate, decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters)
opt_init_fn, opt_update_fn = nadamw(
- learning_rate=lr_schedule_fn,
- b1=1.0 - hyperparameters.one_minus_beta1,
- b2=hyperparameters.beta2,
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay)
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ learning_rate=lr_schedule_fn,
+ b1=1.0 - hyperparameters.one_minus_beta1,
+ b2=hyperparameters.beta2,
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ )
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -281,37 +304,43 @@ def update_params(
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -351,14 +380,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py
index a2f9fb4c5..ecd299988 100644
--- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py
+++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py
@@ -21,33 +21,30 @@
class NAdamW(torch.optim.Optimizer):
r"""Implements NAdamW algorithm.
- See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
- the NAdam algorithm (there is also a comment in the code which highlights
- the only difference of NAdamW and AdamW).
- For further details regarding the algorithm we refer to
- `Decoupled Weight Decay Regularization`_.
-
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its square (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-8)
- weight_decay (float, optional): weight decay coefficient (default: 1e-2)
- .. _Decoupled Weight Decay Regularization:
- https://arxiv.org/abs/1711.05101
- .. _On the Convergence of Adam and Beyond:
- https://openreview.net/forum?id=ryQu7f-RZ
+ See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
+ the NAdam algorithm (there is also a comment in the code which highlights
+ the only difference of NAdamW and AdamW).
+ For further details regarding the algorithm we refer to
+ `Decoupled Weight Decay Regularization`_.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+ .. _Decoupled Weight Decay Regularization:
+ https://arxiv.org/abs/1711.05101
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
"""
- def __init__(self,
- params,
- lr=1e-3,
- betas=(0.9, 0.999),
- eps=1e-8,
- weight_decay=1e-2):
+ def __init__(
+ self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2
+ ):
if not 0.0 <= lr:
raise ValueError(f'Invalid learning rate: {lr}')
if not 0.0 <= eps:
@@ -59,7 +56,10 @@ def __init__(self,
if not 0.0 <= weight_decay:
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
defaults = {
- 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay
+ 'lr': lr,
+ 'betas': betas,
+ 'eps': eps,
+ 'weight_decay': weight_decay,
}
super().__init__(params, defaults)
@@ -67,7 +67,8 @@ def __setstate__(self, state):
super().__setstate__(state)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
- state_values[0]['step'])
+ state_values[0]['step']
+ )
if not step_is_tensor:
for s in state_values:
s['step'] = torch.tensor(float(s['step']))
@@ -76,9 +77,9 @@ def __setstate__(self, state):
def step(self, closure=None):
"""Performs a single optimization step.
- Args:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
"""
self._cuda_graph_capture_health_check()
@@ -107,51 +108,57 @@ def step(self, closure=None):
# State initialization
if len(state) == 0:
- state['step'] = torch.tensor(0.)
+ state['step'] = torch.tensor(0.0)
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
state_steps.append(state['step'])
nadamw(
- params_with_grad,
- grads,
- exp_avgs,
- exp_avg_sqs,
- state_steps,
- beta1=beta1,
- beta2=beta2,
- lr=group['lr'],
- weight_decay=group['weight_decay'],
- eps=group['eps'])
+ params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ state_steps,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ )
return loss
-def nadamw(params: List[Tensor],
- grads: List[Tensor],
- exp_avgs: List[Tensor],
- exp_avg_sqs: List[Tensor],
- state_steps: List[Tensor],
- beta1: float,
- beta2: float,
- lr: float,
- weight_decay: float,
- eps: float) -> None:
+def nadamw(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ state_steps: List[Tensor],
+ beta1: float,
+ beta2: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+) -> None:
r"""Functional API that performs NAdamW algorithm computation.
- See NAdamW class for details.
+ See NAdamW class for details.
"""
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
- 'API has changed, `state_steps` argument must contain a list of' +
- ' singleton tensors')
+ 'API has changed, `state_steps` argument must contain a list of'
+ + ' singleton tensors'
+ )
for i, param in enumerate(params):
grad = grads[i]
@@ -189,54 +196,59 @@ def nadamw(params: List[Tensor],
exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_state
del rng
optimizer_state = {
- 'optimizer':
- NAdamW(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(1.0 - hyperparameters.one_minus_beta1,
- hyperparameters.beta2),
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay),
+ 'optimizer': NAdamW(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2),
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ ),
}
def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup = LinearLR(
- optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
+ optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
- optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
+ optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
+ )
optimizer_state['scheduler'] = pytorch_cosine_warmup(
- workload.step_hint, hyperparameters, optimizer_state['optimizer'])
+ workload.step_hint, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -248,26 +260,30 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -280,7 +296,8 @@ def update_params(
if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -289,31 +306,38 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -353,14 +377,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py
index a37b0d341..0d8054135 100644
--- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py
+++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py
@@ -21,33 +21,30 @@
class NAdamW(torch.optim.Optimizer):
r"""Implements NAdamW algorithm.
- See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
- the NAdam algorithm (there is also a comment in the code which highlights
- the only difference of NAdamW and AdamW).
- For further details regarding the algorithm we refer to
- `Decoupled Weight Decay Regularization`_.
-
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its square (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-8)
- weight_decay (float, optional): weight decay coefficient (default: 1e-2)
- .. _Decoupled Weight Decay Regularization:
- https://arxiv.org/abs/1711.05101
- .. _On the Convergence of Adam and Beyond:
- https://openreview.net/forum?id=ryQu7f-RZ
+ See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
+ the NAdam algorithm (there is also a comment in the code which highlights
+ the only difference of NAdamW and AdamW).
+ For further details regarding the algorithm we refer to
+ `Decoupled Weight Decay Regularization`_.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+ .. _Decoupled Weight Decay Regularization:
+ https://arxiv.org/abs/1711.05101
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
"""
- def __init__(self,
- params,
- lr=1e-3,
- betas=(0.9, 0.999),
- eps=1e-8,
- weight_decay=1e-2):
+ def __init__(
+ self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2
+ ):
if not 0.0 <= lr:
raise ValueError(f'Invalid learning rate: {lr}')
if not 0.0 <= eps:
@@ -59,7 +56,10 @@ def __init__(self,
if not 0.0 <= weight_decay:
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
defaults = {
- 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay
+ 'lr': lr,
+ 'betas': betas,
+ 'eps': eps,
+ 'weight_decay': weight_decay,
}
super().__init__(params, defaults)
@@ -67,7 +67,8 @@ def __setstate__(self, state):
super().__setstate__(state)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
- state_values[0]['step'])
+ state_values[0]['step']
+ )
if not step_is_tensor:
for s in state_values:
s['step'] = torch.tensor(float(s['step']))
@@ -76,9 +77,9 @@ def __setstate__(self, state):
def step(self, closure=None):
"""Performs a single optimization step.
- Args:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
"""
self._cuda_graph_capture_health_check()
@@ -107,51 +108,57 @@ def step(self, closure=None):
# State initialization
if len(state) == 0:
- state['step'] = torch.tensor(0.)
+ state['step'] = torch.tensor(0.0)
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
state_steps.append(state['step'])
nadamw(
- params_with_grad,
- grads,
- exp_avgs,
- exp_avg_sqs,
- state_steps,
- beta1=beta1,
- beta2=beta2,
- lr=group['lr'],
- weight_decay=group['weight_decay'],
- eps=group['eps'])
+ params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ state_steps,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ )
return loss
-def nadamw(params: List[Tensor],
- grads: List[Tensor],
- exp_avgs: List[Tensor],
- exp_avg_sqs: List[Tensor],
- state_steps: List[Tensor],
- beta1: float,
- beta2: float,
- lr: float,
- weight_decay: float,
- eps: float) -> None:
+def nadamw(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ state_steps: List[Tensor],
+ beta1: float,
+ beta2: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+) -> None:
r"""Functional API that performs NAdamW algorithm computation.
- See NAdamW class for details.
+ See NAdamW class for details.
"""
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
- 'API has changed, `state_steps` argument must contain a list of' +
- ' singleton tensors')
+ 'API has changed, `state_steps` argument must contain a list of'
+ + ' singleton tensors'
+ )
for i, param in enumerate(params):
grad = grads[i]
@@ -189,54 +196,59 @@ def nadamw(params: List[Tensor],
exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_state
del rng
optimizer_state = {
- 'optimizer':
- NAdamW(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(1.0 - hyperparameters.one_minus_beta1,
- hyperparameters.beta2),
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay),
+ 'optimizer': NAdamW(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2),
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ ),
}
def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup = LinearLR(
- optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
+ optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
- optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
+ optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
+ )
optimizer_state['scheduler'] = pytorch_cosine_warmup(
- workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer'])
+ workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -248,26 +260,30 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -280,7 +296,8 @@ def update_params(
if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -289,31 +306,38 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -353,14 +377,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py
index 78c3b5b3e..fb322bd5a 100644
--- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py
+++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py
@@ -4,15 +4,17 @@
# isort: off
# We have to turn off isort here to resolve a conflict between isort and yapf.
-from typing import (Any,
- Callable,
- Dict,
- Iterator,
- List,
- NamedTuple,
- Optional,
- Tuple,
- Union)
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ NamedTuple,
+ Optional,
+ Tuple,
+ Union,
+)
# isort: on
import chex
@@ -27,27 +29,26 @@
_GRAD_CLIP_EPS = 1e-6
HPARAMS = {
- "dropout_rate": 0.1,
- "learning_rate": 0.0017486387539278373,
- "one_minus_beta1": 0.06733926164,
- "beta2": 0.9955159689799007,
- "weight_decay": 0.08121616522670176,
- "warmup_factor": 0.02
+ 'dropout_rate': 0.1,
+ 'learning_rate': 0.0017486387539278373,
+ 'one_minus_beta1': 0.06733926164,
+ 'beta2': 0.9955159689799007,
+ 'weight_decay': 0.08121616522670176,
+ 'warmup_factor': 0.02,
}
# Forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py
def nadamw(
- learning_rate: Union[float, optax.Schedule],
- b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- weight_decay: float = 0.0,
- weight_decay_mask: Optional[Union[Any, Callable[[optax.Params],
- Any]]] = None,
+ learning_rate: Union[float, optax.Schedule],
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ weight_decay: float = 0.0,
+ weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None,
) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
@@ -82,19 +83,22 @@ def nadamw(
An (init_fn, update_fn) tuple.
"""
return optax.chain(
- scale_by_nadam(b1, b2, eps, eps_root, debias),
- optax.add_decayed_weights(weight_decay, weight_decay_mask),
- scale_by_learning_rate(learning_rate))
+ scale_by_nadam(b1, b2, eps, eps_root, debias),
+ optax.add_decayed_weights(weight_decay, weight_decay_mask),
+ scale_by_learning_rate(learning_rate),
+ )
# All functions below are forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py
-def scale_by_nadam(b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- power: float = 0.5) -> optax.GradientTransformation:
+def scale_by_nadam(
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ power: float = 0.5,
+) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
References:
@@ -133,7 +137,8 @@ def update_fn(updates, state, params=None):
mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
nu_hat = nu if not debias else _bias_correction(nu, b2, count)
updates = jax.tree.map(
- lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
+ lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat
+ )
return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
return optax.GradientTransformation(init_fn, update_fn)
@@ -141,6 +146,7 @@ def update_fn(updates, state, params=None):
class ScaleByAdamState(NamedTuple):
"""State for the NAdam algorithm."""
+
count: chex.Array # shape=(), dtype=jnp.int32.
mu: optax.Updates
nu: optax.Updates
@@ -149,7 +155,8 @@ class ScaleByAdamState(NamedTuple):
def _update_moment(updates, moments, decay, order):
"""Compute the exponential moving average of the `order-th` moment."""
return jax.tree.map(
- lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
+ lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments
+ )
def _bias_correction(moment, decay, count):
@@ -165,11 +172,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True):
return optax.scale(m * learning_rate)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -182,101 +191,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters['warmup_factor'] * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters['learning_rate'],
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters['learning_rate'],
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters['learning_rate'], decay_steps=cosine_steps)
+ init_value=hyperparameters['learning_rate'], decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint, hyperparameters)
opt_init_fn, opt_update_fn = nadamw(
- learning_rate=lr_schedule_fn,
- b1=1.0 - hyperparameters['one_minus_beta1'],
- b2=hyperparameters['beta2'],
- eps=1e-8,
- weight_decay=hyperparameters['weight_decay'])
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ learning_rate=lr_schedule_fn,
+ b1=1.0 - hyperparameters['one_minus_beta1'],
+ b2=hyperparameters['beta2'],
+ eps=1e-8,
+ weight_decay=hyperparameters['weight_decay'],
+ )
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -296,37 +319,43 @@ def update_params(
grad_clip = hyperparameters['grad_clip']
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -366,14 +395,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py
index ffe854a0e..99d996bb9 100644
--- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py
+++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py
@@ -4,15 +4,17 @@
# isort: off
# We have to turn off isort here to resolve a conflict between isort and yapf.
-from typing import (Any,
- Callable,
- Dict,
- Iterator,
- List,
- NamedTuple,
- Optional,
- Tuple,
- Union)
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterator,
+ List,
+ NamedTuple,
+ Optional,
+ Tuple,
+ Union,
+)
# isort: on
import chex
@@ -27,27 +29,26 @@
_GRAD_CLIP_EPS = 1e-6
HPARAMS = {
- "dropout_rate": 0.1,
- "learning_rate": 0.0017486387539278373,
- "one_minus_beta1": 0.06733926164,
- "beta2": 0.9955159689799007,
- "weight_decay": 0.08121616522670176,
- "warmup_factor": 0.02
+ 'dropout_rate': 0.1,
+ 'learning_rate': 0.0017486387539278373,
+ 'one_minus_beta1': 0.06733926164,
+ 'beta2': 0.9955159689799007,
+ 'weight_decay': 0.08121616522670176,
+ 'warmup_factor': 0.02,
}
# Forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/alias.py
def nadamw(
- learning_rate: Union[float, optax.Schedule],
- b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- weight_decay: float = 0.0,
- weight_decay_mask: Optional[Union[Any, Callable[[optax.Params],
- Any]]] = None,
+ learning_rate: Union[float, optax.Schedule],
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ weight_decay: float = 0.0,
+ weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], Any]]] = None,
) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
@@ -82,19 +83,22 @@ def nadamw(
An (init_fn, update_fn) tuple.
"""
return optax.chain(
- scale_by_nadam(b1, b2, eps, eps_root, debias),
- optax.add_decayed_weights(weight_decay, weight_decay_mask),
- scale_by_learning_rate(learning_rate))
+ scale_by_nadam(b1, b2, eps, eps_root, debias),
+ optax.add_decayed_weights(weight_decay, weight_decay_mask),
+ scale_by_learning_rate(learning_rate),
+ )
# All functions below are forked from
# github.com/google/init2winit/blob/master/init2winit/optimizer_lib/transform.py
-def scale_by_nadam(b1: float = 0.9,
- b2: float = 0.999,
- eps: float = 1e-8,
- eps_root: float = 0.0,
- debias: bool = True,
- power: float = 0.5) -> optax.GradientTransformation:
+def scale_by_nadam(
+ b1: float = 0.9,
+ b2: float = 0.999,
+ eps: float = 1e-8,
+ eps_root: float = 0.0,
+ debias: bool = True,
+ power: float = 0.5,
+) -> optax.GradientTransformation:
"""Rescale updates according to the NAdam algorithm.
References:
@@ -133,7 +137,8 @@ def update_fn(updates, state, params=None):
mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
nu_hat = nu if not debias else _bias_correction(nu, b2, count)
updates = jax.tree.map(
- lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
+ lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat
+ )
return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
return optax.GradientTransformation(init_fn, update_fn)
@@ -141,6 +146,7 @@ def update_fn(updates, state, params=None):
class ScaleByAdamState(NamedTuple):
"""State for the NAdam algorithm."""
+
count: chex.Array # shape=(), dtype=jnp.int32.
mu: optax.Updates
nu: optax.Updates
@@ -149,7 +155,8 @@ class ScaleByAdamState(NamedTuple):
def _update_moment(updates, moments, decay, order):
"""Compute the exponential moving average of the `order-th` moment."""
return jax.tree.map(
- lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
+ lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments
+ )
def _bias_correction(moment, decay, count):
@@ -165,11 +172,13 @@ def scale_by_learning_rate(learning_rate, flip_sign=True):
return optax.scale(m * learning_rate)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_params
del model_state
@@ -182,101 +191,115 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
# Create learning rate schedule.
warmup_steps = int(hyperparameters['warmup_factor'] * step_hint)
warmup_fn = optax.linear_schedule(
- init_value=0.,
- end_value=hyperparameters['learning_rate'],
- transition_steps=warmup_steps)
+ init_value=0.0,
+ end_value=hyperparameters['learning_rate'],
+ transition_steps=warmup_steps,
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_fn = optax.cosine_decay_schedule(
- init_value=hyperparameters['learning_rate'], decay_steps=cosine_steps)
+ init_value=hyperparameters['learning_rate'], decay_steps=cosine_steps
+ )
schedule_fn = optax.join_schedules(
- schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps])
+ schedules=[warmup_fn, cosine_fn], boundaries=[warmup_steps]
+ )
return schedule_fn
# Create optimizer + LR schedule.
lr_schedule_fn = jax_cosine_warmup(workload.step_hint * 0.75, hyperparameters)
opt_init_fn, opt_update_fn = nadamw(
- learning_rate=lr_schedule_fn,
- b1=1.0 - hyperparameters['one_minus_beta1'],
- b2=hyperparameters['beta2'],
- eps=1e-8,
- weight_decay=hyperparameters['weight_decay'])
- params_zeros_like = jax.tree.map(lambda s: jnp.zeros(s.shape_tuple),
- workload.param_shapes)
+ learning_rate=lr_schedule_fn,
+ b1=1.0 - hyperparameters['one_minus_beta1'],
+ b2=hyperparameters['beta2'],
+ eps=1e-8,
+ weight_decay=hyperparameters['weight_decay'],
+ )
+ params_zeros_like = jax.tree.map(
+ lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
+ )
optimizer_state = opt_init_fn(params_zeros_like)
return jax_utils.replicate(optimizer_state), opt_update_fn
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
- static_broadcasted_argnums=(0, 1),
- donate_argnums=(2, 3, 4))
-def pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- rng,
- grad_clip,
- label_smoothing):
-
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, None, 0, 0, 0, 0, 0, None, None),
+ static_broadcasted_argnums=(0, 1),
+ donate_argnums=(2, 3, 4),
+)
+def pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ rng,
+ grad_clip,
+ label_smoothing,
+):
def _loss_fn(params):
"""Loss function used for training."""
logits, new_model_state = workload.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.TRAIN,
- rng,
- update_batch_norm=True)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.TRAIN,
+ rng,
+ update_batch_norm=True,
+ )
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
return summed_loss, (n_valid_examples, new_model_state)
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
(summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn(
- current_param_container)
+ current_param_container
+ )
# Get correct global mean loss and grad.
(summed_loss, n_valid_examples, grad) = lax.psum(
- (summed_loss, n_valid_examples, grad), axis_name='batch')
+ (summed_loss, n_valid_examples, grad), axis_name='batch'
+ )
loss = summed_loss / n_valid_examples
grad = jax.tree.map(lambda x: x / n_valid_examples, grad)
grad_norm = jnp.sqrt(
- sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad)))
+ sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))
+ )
if grad_clip is not None:
grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS)
grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0)
grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad)
- updates, new_optimizer_state = opt_update_fn(grad, optimizer_state,
- current_param_container)
+ updates, new_optimizer_state = opt_update_fn(
+ grad, optimizer_state, current_param_container
+ )
updated_params = optax.apply_updates(current_param_container, updates)
return new_optimizer_state, updated_params, new_model_state, loss, grad_norm
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -296,37 +319,43 @@ def update_params(
grad_clip = hyperparameters['grad_clip']
else:
grad_clip = None
- outputs = pmapped_train_step(workload,
- opt_update_fn,
- model_state,
- optimizer_state,
- current_param_container,
- batch,
- per_device_rngs,
- grad_clip,
- label_smoothing)
+ outputs = pmapped_train_step(
+ workload,
+ opt_update_fn,
+ model_state,
+ optimizer_state,
+ current_param_container,
+ batch,
+ per_device_rngs,
+ grad_clip,
+ label_smoothing,
+ )
new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs
# Log loss, grad_norm.
if global_step % 100 == 0 and workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss[0],
- 'grad_norm': grad_norm[0],
- }, global_step)
+ {
+ 'loss': loss[0],
+ 'grad_norm': grad_norm[0],
+ },
+ global_step,
+ )
return (new_optimizer_state, opt_update_fn), new_params, new_model_state
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -366,14 +395,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py
index 554a28762..cc54e3b4e 100644
--- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py
+++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py
@@ -18,12 +18,12 @@
USE_PYTORCH_DDP = pytorch_setup()[0]
HPARAMS = {
- "dropout_rate": 0.1,
- "learning_rate": 0.0017486387539278373,
- "one_minus_beta1": 0.06733926164,
- "beta2": 0.9955159689799007,
- "weight_decay": 0.08121616522670176,
- "warmup_factor": 0.02
+ 'dropout_rate': 0.1,
+ 'learning_rate': 0.0017486387539278373,
+ 'one_minus_beta1': 0.06733926164,
+ 'beta2': 0.9955159689799007,
+ 'weight_decay': 0.08121616522670176,
+ 'warmup_factor': 0.02,
}
HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS)
@@ -32,33 +32,30 @@
class NAdamW(torch.optim.Optimizer):
r"""Implements NAdamW algorithm.
- See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
- the NAdam algorithm (there is also a comment in the code which highlights
- the only difference of NAdamW and AdamW).
- For further details regarding the algorithm we refer to
- `Decoupled Weight Decay Regularization`_.
-
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its square (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-8)
- weight_decay (float, optional): weight decay coefficient (default: 1e-2)
- .. _Decoupled Weight Decay Regularization:
- https://arxiv.org/abs/1711.05101
- .. _On the Convergence of Adam and Beyond:
- https://openreview.net/forum?id=ryQu7f-RZ
+ See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
+ the NAdam algorithm (there is also a comment in the code which highlights
+ the only difference of NAdamW and AdamW).
+ For further details regarding the algorithm we refer to
+ `Decoupled Weight Decay Regularization`_.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+ .. _Decoupled Weight Decay Regularization:
+ https://arxiv.org/abs/1711.05101
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
"""
- def __init__(self,
- params,
- lr=1e-3,
- betas=(0.9, 0.999),
- eps=1e-8,
- weight_decay=1e-2):
+ def __init__(
+ self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2
+ ):
if not 0.0 <= lr:
raise ValueError(f'Invalid learning rate: {lr}')
if not 0.0 <= eps:
@@ -70,7 +67,10 @@ def __init__(self,
if not 0.0 <= weight_decay:
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
defaults = {
- 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay
+ 'lr': lr,
+ 'betas': betas,
+ 'eps': eps,
+ 'weight_decay': weight_decay,
}
super().__init__(params, defaults)
@@ -78,7 +78,8 @@ def __setstate__(self, state):
super().__setstate__(state)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
- state_values[0]['step'])
+ state_values[0]['step']
+ )
if not step_is_tensor:
for s in state_values:
s['step'] = torch.tensor(float(s['step']))
@@ -87,9 +88,9 @@ def __setstate__(self, state):
def step(self, closure=None):
"""Performs a single optimization step.
- Args:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
"""
self._cuda_graph_capture_health_check()
@@ -118,51 +119,57 @@ def step(self, closure=None):
# State initialization
if len(state) == 0:
- state['step'] = torch.tensor(0.)
+ state['step'] = torch.tensor(0.0)
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
state_steps.append(state['step'])
nadamw(
- params_with_grad,
- grads,
- exp_avgs,
- exp_avg_sqs,
- state_steps,
- beta1=beta1,
- beta2=beta2,
- lr=group['lr'],
- weight_decay=group['weight_decay'],
- eps=group['eps'])
+ params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ state_steps,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ )
return loss
-def nadamw(params: List[Tensor],
- grads: List[Tensor],
- exp_avgs: List[Tensor],
- exp_avg_sqs: List[Tensor],
- state_steps: List[Tensor],
- beta1: float,
- beta2: float,
- lr: float,
- weight_decay: float,
- eps: float) -> None:
+def nadamw(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ state_steps: List[Tensor],
+ beta1: float,
+ beta2: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+) -> None:
r"""Functional API that performs NAdamW algorithm computation.
- See NAdamW class for details.
+ See NAdamW class for details.
"""
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
- 'API has changed, `state_steps` argument must contain a list of' +
- ' singleton tensors')
+ 'API has changed, `state_steps` argument must contain a list of'
+ + ' singleton tensors'
+ )
for i, param in enumerate(params):
grad = grads[i]
@@ -200,11 +207,13 @@ def nadamw(params: List[Tensor],
exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_state
del rng
@@ -213,44 +222,47 @@ def init_optimizer_state(workload: spec.Workload,
hyperparameters = HPARAMS
optimizer_state = {
- 'optimizer':
- NAdamW(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(1.0 - hyperparameters.one_minus_beta1,
- hyperparameters.beta2),
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay),
+ 'optimizer': NAdamW(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2),
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ ),
}
def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup = LinearLR(
- optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
+ optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
- optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
+ optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
+ )
optimizer_state['scheduler'] = pytorch_cosine_warmup(
- workload.step_hint, hyperparameters, optimizer_state['optimizer'])
+ workload.step_hint, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -265,26 +277,30 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -297,7 +313,8 @@ def update_params(
if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -306,31 +323,38 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -370,14 +394,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py
index e4317fa18..bd065dc06 100644
--- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py
+++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py
@@ -18,12 +18,12 @@
USE_PYTORCH_DDP = pytorch_setup()[0]
HPARAMS = {
- "dropout_rate": 0.1,
- "learning_rate": 0.0017486387539278373,
- "one_minus_beta1": 0.06733926164,
- "beta2": 0.9955159689799007,
- "weight_decay": 0.08121616522670176,
- "warmup_factor": 0.02
+ 'dropout_rate': 0.1,
+ 'learning_rate': 0.0017486387539278373,
+ 'one_minus_beta1': 0.06733926164,
+ 'beta2': 0.9955159689799007,
+ 'weight_decay': 0.08121616522670176,
+ 'warmup_factor': 0.02,
}
HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS)
@@ -32,33 +32,30 @@
class NAdamW(torch.optim.Optimizer):
r"""Implements NAdamW algorithm.
- See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
- the NAdam algorithm (there is also a comment in the code which highlights
- the only difference of NAdamW and AdamW).
- For further details regarding the algorithm we refer to
- `Decoupled Weight Decay Regularization`_.
-
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its square (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-8)
- weight_decay (float, optional): weight decay coefficient (default: 1e-2)
- .. _Decoupled Weight Decay Regularization:
- https://arxiv.org/abs/1711.05101
- .. _On the Convergence of Adam and Beyond:
- https://openreview.net/forum?id=ryQu7f-RZ
+ See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of
+ the NAdam algorithm (there is also a comment in the code which highlights
+ the only difference of NAdamW and AdamW).
+ For further details regarding the algorithm we refer to
+ `Decoupled Weight Decay Regularization`_.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+ .. _Decoupled Weight Decay Regularization:
+ https://arxiv.org/abs/1711.05101
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
"""
- def __init__(self,
- params,
- lr=1e-3,
- betas=(0.9, 0.999),
- eps=1e-8,
- weight_decay=1e-2):
+ def __init__(
+ self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2
+ ):
if not 0.0 <= lr:
raise ValueError(f'Invalid learning rate: {lr}')
if not 0.0 <= eps:
@@ -70,7 +67,10 @@ def __init__(self,
if not 0.0 <= weight_decay:
raise ValueError(f'Invalid weight_decay value: {weight_decay}')
defaults = {
- 'lr': lr, 'betas': betas, 'eps': eps, 'weight_decay': weight_decay
+ 'lr': lr,
+ 'betas': betas,
+ 'eps': eps,
+ 'weight_decay': weight_decay,
}
super().__init__(params, defaults)
@@ -78,7 +78,8 @@ def __setstate__(self, state):
super().__setstate__(state)
state_values = list(self.state.values())
step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
- state_values[0]['step'])
+ state_values[0]['step']
+ )
if not step_is_tensor:
for s in state_values:
s['step'] = torch.tensor(float(s['step']))
@@ -87,9 +88,9 @@ def __setstate__(self, state):
def step(self, closure=None):
"""Performs a single optimization step.
- Args:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
"""
self._cuda_graph_capture_health_check()
@@ -118,51 +119,57 @@ def step(self, closure=None):
# State initialization
if len(state) == 0:
- state['step'] = torch.tensor(0.)
+ state['step'] = torch.tensor(0.0)
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(
- p, memory_format=torch.preserve_format)
+ p, memory_format=torch.preserve_format
+ )
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
state_steps.append(state['step'])
nadamw(
- params_with_grad,
- grads,
- exp_avgs,
- exp_avg_sqs,
- state_steps,
- beta1=beta1,
- beta2=beta2,
- lr=group['lr'],
- weight_decay=group['weight_decay'],
- eps=group['eps'])
+ params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ state_steps,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ )
return loss
-def nadamw(params: List[Tensor],
- grads: List[Tensor],
- exp_avgs: List[Tensor],
- exp_avg_sqs: List[Tensor],
- state_steps: List[Tensor],
- beta1: float,
- beta2: float,
- lr: float,
- weight_decay: float,
- eps: float) -> None:
+def nadamw(
+ params: List[Tensor],
+ grads: List[Tensor],
+ exp_avgs: List[Tensor],
+ exp_avg_sqs: List[Tensor],
+ state_steps: List[Tensor],
+ beta1: float,
+ beta2: float,
+ lr: float,
+ weight_decay: float,
+ eps: float,
+) -> None:
r"""Functional API that performs NAdamW algorithm computation.
- See NAdamW class for details.
+ See NAdamW class for details.
"""
if not all(isinstance(t, torch.Tensor) for t in state_steps):
raise RuntimeError(
- 'API has changed, `state_steps` argument must contain a list of' +
- ' singleton tensors')
+ 'API has changed, `state_steps` argument must contain a list of'
+ + ' singleton tensors'
+ )
for i, param in enumerate(params):
grad = grads[i]
@@ -200,11 +207,13 @@ def nadamw(params: List[Tensor],
exp_avg.sub_(grad, alpha=1 - beta1).div_(beta1)
-def init_optimizer_state(workload: spec.Workload,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- rng: spec.RandomState) -> spec.OptimizerState:
+def init_optimizer_state(
+ workload: spec.Workload,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ rng: spec.RandomState,
+) -> spec.OptimizerState:
"""Creates a NAdamW optimizer and a learning rate schedule."""
del model_state
del rng
@@ -213,44 +222,47 @@ def init_optimizer_state(workload: spec.Workload,
hyperparameters = HPARAMS
optimizer_state = {
- 'optimizer':
- NAdamW(
- model_params.parameters(),
- lr=hyperparameters.learning_rate,
- betas=(1.0 - hyperparameters.one_minus_beta1,
- hyperparameters.beta2),
- eps=1e-8,
- weight_decay=hyperparameters.weight_decay),
+ 'optimizer': NAdamW(
+ model_params.parameters(),
+ lr=hyperparameters.learning_rate,
+ betas=(1.0 - hyperparameters.one_minus_beta1, hyperparameters.beta2),
+ eps=1e-8,
+ weight_decay=hyperparameters.weight_decay,
+ ),
}
def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer):
warmup_steps = int(hyperparameters.warmup_factor * step_hint)
warmup = LinearLR(
- optimizer, start_factor=1e-10, end_factor=1., total_iters=warmup_steps)
+ optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps
+ )
cosine_steps = max(step_hint - warmup_steps, 1)
cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps)
return SequentialLR(
- optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps])
+ optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps]
+ )
optimizer_state['scheduler'] = pytorch_cosine_warmup(
- workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer'])
+ workload.step_hint * 0.75, hyperparameters, optimizer_state['optimizer']
+ )
return optimizer_state
def update_params(
- workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- batch: Dict[str, spec.Tensor],
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn:
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ batch: Dict[str, spec.Tensor],
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
del current_params_types
del loss_type
@@ -265,26 +277,30 @@ def update_params(
optimizer_state['optimizer'].zero_grad()
logits_batch, new_model_state = workload.model_fn(
- params=current_model,
- augmented_and_preprocessed_input_batch=batch,
- model_state=model_state,
- mode=spec.ForwardPassMode.TRAIN,
- rng=rng,
- update_batch_norm=True)
+ params=current_model,
+ augmented_and_preprocessed_input_batch=batch,
+ model_state=model_state,
+ mode=spec.ForwardPassMode.TRAIN,
+ rng=rng,
+ update_batch_norm=True,
+ )
label_smoothing = (
- hyperparameters.label_smoothing if hasattr(hyperparameters,
- 'label_smoothing') else 0.0)
+ hyperparameters.label_smoothing
+ if hasattr(hyperparameters, 'label_smoothing')
+ else 0.0
+ )
if hasattr(hyperparameters, 'grad_clip'):
grad_clip = hyperparameters.grad_clip
else:
grad_clip = None
loss_dict = workload.loss_fn(
- label_batch=batch['targets'],
- logits_batch=logits_batch,
- mask_batch=batch.get('weights'),
- label_smoothing=label_smoothing)
+ label_batch=batch['targets'],
+ logits_batch=logits_batch,
+ mask_batch=batch.get('weights'),
+ label_smoothing=label_smoothing,
+ )
summed_loss = loss_dict['summed']
n_valid_examples = loss_dict['n_valid_examples']
if USE_PYTORCH_DDP:
@@ -297,7 +313,8 @@ def update_params(
if grad_clip is not None:
torch.nn.utils.clip_grad_norm_(
- current_model.parameters(), max_norm=grad_clip)
+ current_model.parameters(), max_norm=grad_clip
+ )
optimizer_state['optimizer'].step()
optimizer_state['scheduler'].step()
@@ -306,31 +323,38 @@ def update_params(
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
- torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+ torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2
+ )
if workload.metrics_logger is not None:
workload.metrics_logger.append_scalar_metrics(
- {
- 'loss': loss.item(),
- 'grad_norm': grad_norm.item(),
- }, global_step)
- logging.info('%d) loss = %0.3f, grad_norm = %0.3f',
- global_step,
- loss.item(),
- grad_norm.item())
+ {
+ 'loss': loss.item(),
+ 'grad_norm': grad_norm.item(),
+ },
+ global_step,
+ )
+ logging.info(
+ '%d) loss = %0.3f, grad_norm = %0.3f',
+ global_step,
+ loss.item(),
+ grad_norm.item(),
+ )
return (optimizer_state, current_param_container, new_model_state)
-def prepare_for_eval(workload: spec.Workload,
- current_param_container: spec.ParameterContainer,
- current_params_types: spec.ParameterTypeTree,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- loss_type: spec.LossType,
- optimizer_state: spec.OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: spec.RandomState) -> spec.UpdateReturn:
+def prepare_for_eval(
+ workload: spec.Workload,
+ current_param_container: spec.ParameterContainer,
+ current_params_types: spec.ParameterTypeTree,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ loss_type: spec.LossType,
+ optimizer_state: spec.OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: spec.RandomState,
+) -> spec.UpdateReturn:
"""Return (updated_optimizer_state, updated_params)."""
del workload
del hyperparameters
@@ -370,14 +394,16 @@ def get_batch_size(workload_name):
raise ValueError(f'Unsupported workload name: {workload_name}.')
-def data_selection(workload: spec.Workload,
- input_queue: Iterator[Dict[str, spec.Tensor]],
- optimizer_state: spec.OptimizerState,
- current_param_container: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- hyperparameters: spec.Hyperparameters,
- global_step: int,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+def data_selection(
+ workload: spec.Workload,
+ input_queue: Iterator[Dict[str, spec.Tensor]],
+ optimizer_state: spec.OptimizerState,
+ current_param_container: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ hyperparameters: spec.Hyperparameters,
+ global_step: int,
+ rng: spec.RandomState,
+) -> Dict[str, spec.Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
"""
From 531c99ec5545a01905b53bd2f92c5805dec1576b Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 10:20:52 +0200
Subject: [PATCH 106/123] Format datasets/
---
datasets/dataset_setup.py | 519 ++++++++++++++++-------------
datasets/librispeech_preprocess.py | 74 ++--
datasets/librispeech_tokenizer.py | 42 ++-
3 files changed, 366 insertions(+), 269 deletions(-)
diff --git a/datasets/dataset_setup.py b/datasets/dataset_setup.py
index efe923dbe..e110930cd 100644
--- a/datasets/dataset_setup.py
+++ b/datasets/dataset_setup.py
@@ -72,8 +72,7 @@
from torchvision.datasets import CIFAR10
from algoperf.workloads.wmt import tokenizer
-from algoperf.workloads.wmt.input_pipeline import \
- normalize_feature_names
+from algoperf.workloads.wmt.input_pipeline import normalize_feature_names
from datasets import librispeech_preprocess
from datasets import librispeech_tokenizer
@@ -101,84 +100,96 @@
FASTMRI_TEST_TAR_FILENAME = 'knee_singlecoil_test.tar.xz'
flags.DEFINE_boolean(
- 'interactive_deletion',
- True,
- 'If true, user will be prompted before any files are deleted. If false, no '
- 'files will be deleted.')
+ 'interactive_deletion',
+ True,
+ 'If true, user will be prompted before any files are deleted. If false, no '
+ 'files will be deleted.',
+)
flags.DEFINE_boolean(
- 'all',
- False,
- 'Whether or not to download all datasets. If false, can download some '
- 'combination of datasets by setting the individual dataset flags below.')
-
-flags.DEFINE_boolean('criteo1tb',
- False,
- 'If --all=false, whether or not to download Criteo 1TB.')
-flags.DEFINE_boolean('cifar',
- False,
- 'If --all=false, whether or not to download CIFAR-10.')
-flags.DEFINE_boolean('fastmri',
- False,
- 'If --all=false, whether or not to download FastMRI.')
-flags.DEFINE_boolean('imagenet',
- False,
- 'If --all=false, whether or not to download Imagenet.')
-flags.DEFINE_boolean('librispeech',
- False,
- 'If --all=false, whether or not to download LibriSpeech.')
-flags.DEFINE_boolean('mnist',
- False,
- 'If --all=false, whether or not to download MNIST.')
-flags.DEFINE_boolean('ogbg',
- False,
- 'If --all=false, whether or not to download OGBG.')
-flags.DEFINE_boolean('wmt',
- False,
- 'If --all=false, whether or not to download WMT.')
+ 'all',
+ False,
+ 'Whether or not to download all datasets. If false, can download some '
+ 'combination of datasets by setting the individual dataset flags below.',
+)
+
+flags.DEFINE_boolean(
+ 'criteo1tb', False, 'If --all=false, whether or not to download Criteo 1TB.'
+)
+flags.DEFINE_boolean(
+ 'cifar', False, 'If --all=false, whether or not to download CIFAR-10.'
+)
+flags.DEFINE_boolean(
+ 'fastmri', False, 'If --all=false, whether or not to download FastMRI.'
+)
+flags.DEFINE_boolean(
+ 'imagenet', False, 'If --all=false, whether or not to download Imagenet.'
+)
+flags.DEFINE_boolean(
+ 'librispeech',
+ False,
+ 'If --all=false, whether or not to download LibriSpeech.',
+)
+flags.DEFINE_boolean(
+ 'mnist', False, 'If --all=false, whether or not to download MNIST.'
+)
+flags.DEFINE_boolean(
+ 'ogbg', False, 'If --all=false, whether or not to download OGBG.'
+)
+flags.DEFINE_boolean(
+ 'wmt', False, 'If --all=false, whether or not to download WMT.'
+)
flags.DEFINE_string(
- 'data_dir',
- '~/data',
- 'The path to the folder where datasets should be downloaded.')
+ 'data_dir',
+ '~/data',
+ 'The path to the folder where datasets should be downloaded.',
+)
flags.DEFINE_string(
- 'temp_dir',
- '/tmp/mlcommons',
- 'A local path to a folder where temp files can be downloaded.')
+ 'temp_dir',
+ '/tmp/mlcommons',
+ 'A local path to a folder where temp files can be downloaded.',
+)
flags.DEFINE_string(
- 'imagenet_train_url',
- None,
- 'Only necessary if you want this script to `wget` the ImageNet train '
- 'split. If not, you can supply the path to --data_dir in '
- 'submission_runner.py.')
+ 'imagenet_train_url',
+ None,
+ 'Only necessary if you want this script to `wget` the ImageNet train '
+ 'split. If not, you can supply the path to --data_dir in '
+ 'submission_runner.py.',
+)
flags.DEFINE_string(
- 'imagenet_val_url',
- None,
- 'Only necessary if you want this script to `wget` the ImageNet validation '
- 'split. If not, you can supply the path to --data_dir in '
- 'submission_runner.py.')
+ 'imagenet_val_url',
+ None,
+ 'Only necessary if you want this script to `wget` the ImageNet validation '
+ 'split. If not, you can supply the path to --data_dir in '
+ 'submission_runner.py.',
+)
flags.DEFINE_string(
- 'fastmri_knee_singlecoil_train_url',
- None,
- 'Only necessary if you want this script to `wget` the FastMRI train '
- 'split. If not, you can supply the path to --data_dir in '
- 'submission_runner.py.')
+ 'fastmri_knee_singlecoil_train_url',
+ None,
+ 'Only necessary if you want this script to `wget` the FastMRI train '
+ 'split. If not, you can supply the path to --data_dir in '
+ 'submission_runner.py.',
+)
flags.DEFINE_string(
- 'fastmri_knee_singlecoil_val_url',
- None,
- 'Only necessary if you want this script to `wget` the FastMRI validation '
- 'split. If not, you can supply the path to --data_dir in '
- 'submission_runner.py.')
+ 'fastmri_knee_singlecoil_val_url',
+ None,
+ 'Only necessary if you want this script to `wget` the FastMRI validation '
+ 'split. If not, you can supply the path to --data_dir in '
+ 'submission_runner.py.',
+)
flags.DEFINE_string(
- 'fastmri_knee_singlecoil_test_url',
- None,
- 'Only necessary if you want this script to `wget` the FastMRI test '
- 'split. If not, you can supply the path to --data_dir in '
- 'submission_runner.py.')
+ 'fastmri_knee_singlecoil_test_url',
+ None,
+ 'Only necessary if you want this script to `wget` the FastMRI test '
+ 'split. If not, you can supply the path to --data_dir in '
+ 'submission_runner.py.',
+)
flags.DEFINE_integer(
- 'num_decompression_threads',
- 8,
- 'The number of threads to use in parallel when decompressing.')
+ 'num_decompression_threads',
+ 8,
+ 'The number of threads to use in parallel when decompressing.',
+)
flags.DEFINE_string('framework', None, 'Can be either jax or pytorch.')
@@ -186,7 +197,7 @@
FLAGS = flags.FLAGS
-os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
def _maybe_mkdir(d):
@@ -198,8 +209,10 @@ def _maybe_prompt_for_deletion(paths, interactive_deletion):
if not interactive_deletion:
return
files_for_deletion = '\n'.join(paths)
- logging.info('\n\n\nWARNING: the following temp files will be DELETED:'
- f'\n{files_for_deletion}')
+ logging.info(
+ '\n\n\nWARNING: the following temp files will be DELETED:'
+ f'\n{files_for_deletion}'
+ )
delete_str = input('Confirm deletion? [y/N]: ')
if delete_str.lower() == 'y':
del_cmd = 'rm ' + ' '.join(f'"{s}"' for s in paths)
@@ -225,8 +238,9 @@ def _download_url(url, data_dir, name=None):
if os.path.exists(file_path):
while True:
- overwrite = input('File already exists {}.\n Overwrite? (Y/n)'.format(
- file_path)).lower()
+ overwrite = input(
+ 'File already exists {}.\n Overwrite? (Y/n)'.format(file_path)
+ ).lower()
if overwrite in ['y', 'n']:
break
logging.info('Invalid response. Try again.')
@@ -240,17 +254,18 @@ def _download_url(url, data_dir, name=None):
progress_bar.update(chunk_size_in_mib)
f.write(chunk)
progress_bar.close()
- if (progress_bar.total != 0 and progress_bar.n != progress_bar.total):
+ if progress_bar.total != 0 and progress_bar.n != progress_bar.total:
raise RuntimeError(
- ('Download corrupted, size {n} MiB from {url} does not match '
- 'expected size {size} MiB').format(
- url=url, n=progress_bar.n, size=progress_bar.total))
+ (
+ 'Download corrupted, size {n} MiB from {url} does not match '
+ 'expected size {size} MiB'
+ ).format(url=url, n=progress_bar.n, size=progress_bar.total)
+ )
-def download_criteo1tb(data_dir,
- tmp_dir,
- num_decompression_threads,
- interactive_deletion):
+def download_criteo1tb(
+ data_dir, tmp_dir, num_decompression_threads, interactive_deletion
+):
criteo_dir = os.path.join(data_dir, 'criteo1tb')
tmp_criteo_dir = os.path.join(tmp_dir, 'criteo1tb')
_maybe_mkdir(criteo_dir)
@@ -258,47 +273,56 @@ def download_criteo1tb(data_dir,
# Forked from
# https://github.com/iamleot/transferwee/blob/master/transferwee.py.
- user_agent = ('Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:102.0) '
- 'Gecko/20100101 Firefox/102.0')
+ user_agent = (
+ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:102.0) '
+ 'Gecko/20100101 Firefox/102.0'
+ )
criteo_wetransfer_url = (
- 'https://criteo.wetransfer.com/downloads/'
- '4bbea9b4a54baddea549d71271a38e2c20230428071257/d4f0d2')
+ 'https://criteo.wetransfer.com/downloads/'
+ '4bbea9b4a54baddea549d71271a38e2c20230428071257/d4f0d2'
+ )
_, _, transfer_id, security_hash = urllib.parse.urlparse(
- criteo_wetransfer_url).path.split('/')
+ criteo_wetransfer_url
+ ).path.split('/')
session = requests.Session()
- session.headers.update({
+ session.headers.update(
+ {
'User-Agent': user_agent,
'x-requested-with': 'XMLHttpRequest',
- })
+ }
+ )
r = session.get('https://wetransfer.com/')
m = re.search('name="csrf-token" content="([^"]+)"', r.text)
if m:
session.headers.update({'x-csrf-token': m.group(1)})
get_url_request = session.post(
- f'https://wetransfer.com/api/v4/transfers/{transfer_id}/download',
- json={
- 'intent': 'entire_transfer',
- 'security_hash': security_hash,
- })
+ f'https://wetransfer.com/api/v4/transfers/{transfer_id}/download',
+ json={
+ 'intent': 'entire_transfer',
+ 'security_hash': security_hash,
+ },
+ )
session.close()
download_url = get_url_request.json().get('direct_link')
logging.info(f'Downloading ~342GB Criteo 1TB data .zip file:\n{download_url}')
download_request = requests.get( # pylint: disable=missing-timeout
- download_url,
- headers={'User-Agent': user_agent},
- stream=True)
+ download_url, headers={'User-Agent': user_agent}, stream=True
+ )
all_days_zip_filepath = os.path.join(tmp_criteo_dir, 'all_days.zip')
if not FLAGS.skip_download:
download = True
if os.path.exists(all_days_zip_filepath):
while True:
- overwrite = input('File already exists {}.\n Overwrite? (Y/n)'.format(
- all_days_zip_filepath)).lower()
+ overwrite = input(
+ 'File already exists {}.\n Overwrite? (Y/n)'.format(
+ all_days_zip_filepath
+ )
+ ).lower()
if overwrite in ['y', 'n']:
break
logging.info('Invalid response. Try again.')
@@ -324,8 +348,10 @@ def download_criteo1tb(data_dir,
input_path = os.path.join(tmp_criteo_dir, f'day_{day}.gz')
gz_paths.append(input_path)
unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv')
- unzip_cmd = (f'pigz -d -c -p{num_decompression_threads} "{input_path}" > '
- f'"{unzipped_path}"')
+ unzip_cmd = (
+ f'pigz -d -c -p{num_decompression_threads} "{input_path}" > '
+ f'"{unzipped_path}"'
+ )
logging.info(f'Running Criteo unzip command for day {day}:\n{unzip_cmd}')
processes.append(subprocess.Popen(unzip_cmd, shell=True))
for p in processes:
@@ -341,8 +367,7 @@ def download_criteo1tb(data_dir,
unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv')
unzipped_paths.append(unzipped_path)
split_path = os.path.join(criteo_dir, f'day_{day}_')
- split_cmd = ('split -a 2 -d -l 5000000 '
- f'"{unzipped_path}" "{split_path}"')
+ split_cmd = f'split -a 2 -d -l 5000000 "{unzipped_path}" "{split_path}"'
logging.info(f'Running Criteo 1TB split command:\n{split_cmd}')
batch_processes.append(subprocess.Popen(split_cmd, shell=True))
for p in batch_processes:
@@ -362,45 +387,50 @@ def download_cifar(data_dir, framework):
def extract_filename_from_url(url, start_str='knee', end_str='.xz'):
- """ The url filenames are sometimes couched within a urldefense+aws access id
+ """The url filenames are sometimes couched within a urldefense+aws access id
etc. string. Unfortunately querying the content disposition in requests fails
(not provided)... so fast search is done here within the url.
- """
+ """
failure = -1
start = url.find(start_str)
end = url.find(end_str)
if failure in (start, end):
raise ValueError(
- f'Unable to locate filename wrapped in {start_str}--{end_str} in {url}')
+ f'Unable to locate filename wrapped in {start_str}--{end_str} in {url}'
+ )
end += len(end_str) # make it inclusive
return url[start:end]
-def download_fastmri(data_dir,
- fastmri_train_url,
- fastmri_val_url,
- fastmri_test_url):
+def download_fastmri(
+ data_dir, fastmri_train_url, fastmri_val_url, fastmri_test_url
+):
data_dir = os.path.join(data_dir, 'fastmri')
# Download fastmri train dataset
knee_train_filename = extract_filename_from_url(fastmri_train_url)
logging.info(
- 'Downloading fastmri train dataset from {}'.format(fastmri_train_url))
+ 'Downloading fastmri train dataset from {}'.format(fastmri_train_url)
+ )
_download_url(
- url=fastmri_train_url, data_dir=data_dir, name=knee_train_filename)
+ url=fastmri_train_url, data_dir=data_dir, name=knee_train_filename
+ )
# Download fastmri val dataset
knee_val_filename = extract_filename_from_url(fastmri_val_url)
logging.info(
- 'Downloading fastmri val dataset from {}'.format(fastmri_val_url))
+ 'Downloading fastmri val dataset from {}'.format(fastmri_val_url)
+ )
_download_url(url=fastmri_val_url, data_dir=data_dir, name=knee_val_filename)
# Download fastmri test dataset
knee_test_filename = extract_filename_from_url(fastmri_test_url)
logging.info(
- 'Downloading fastmri test dataset from {}'.format(fastmri_test_url))
+ 'Downloading fastmri test dataset from {}'.format(fastmri_test_url)
+ )
_download_url(
- url=fastmri_test_url, data_dir=data_dir, name=knee_test_filename)
+ url=fastmri_test_url, data_dir=data_dir, name=knee_test_filename
+ )
return data_dir
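
extract_filename_from_url recovers the archive name by locating the 'knee' prefix and the '.xz' suffix inside the (often signed) download URL and slicing between them, suffix inclusive. A small self-contained illustration of that slicing with a made-up URL:

# Hypothetical signed URL; only the 'knee...xz' substring matters here.
url = (
  'https://example-bucket.s3.amazonaws.com/'
  'knee_singlecoil_train.tar.xz?AWSAccessKeyId=ABC&Expires=123'
)
start = url.find('knee')
end = url.find('.xz') + len('.xz')  # make the suffix inclusive
assert url[start:end] == 'knee_singlecoil_train.tar.xz'
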
@@ -432,18 +462,18 @@ def setup_fastmri(data_dir):
# Rename folders to match what the workload expects
os.rename(
- os.path.join(data_dir, "singlecoil_train"),
- os.path.join(data_dir, "knee_singlecoil_train"),
+ os.path.join(data_dir, 'singlecoil_train'),
+ os.path.join(data_dir, 'knee_singlecoil_train'),
)
os.rename(
- os.path.join(data_dir, "singlecoil_val"),
- os.path.join(data_dir, "knee_singlecoil_val"),
+ os.path.join(data_dir, 'singlecoil_val'),
+ os.path.join(data_dir, 'knee_singlecoil_val'),
)
os.rename(
- os.path.join(data_dir, "singlecoil_test"),
- os.path.join(data_dir, "knee_singlecoil_test"),
+ os.path.join(data_dir, 'singlecoil_test'),
+ os.path.join(data_dir, 'knee_singlecoil_test'),
)
- logging.info("Set up fastMRI dataset complete")
+ logging.info('Set up fastMRI dataset complete')
def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url):
@@ -456,26 +486,32 @@ def download_imagenet(data_dir, imagenet_train_url, imagenet_val_url):
# been moved to the manual_download_dir.
# Get paths in manual_download_dir.
imagenet_jax_data_dir = os.path.join(data_dir, 'jax')
- manual_download_dir = os.path.join(imagenet_jax_data_dir,
- 'downloads',
- 'manual')
- imagenet_train_download_filepath = os.path.join(manual_download_dir,
- IMAGENET_TRAIN_TAR_FILENAME)
- imagenet_val_download_filepath = os.path.join(manual_download_dir,
- IMAGENET_VAL_TAR_FILENAME)
+ manual_download_dir = os.path.join(
+ imagenet_jax_data_dir, 'downloads', 'manual'
+ )
+ imagenet_train_download_filepath = os.path.join(
+ manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME
+ )
+ imagenet_val_download_filepath = os.path.join(
+ manual_download_dir, IMAGENET_VAL_TAR_FILENAME
+ )
# Download imagenet train dataset
if not os.path.exists(imagenet_train_filepath) and not os.path.exists(
- imagenet_train_download_filepath):
+ imagenet_train_download_filepath
+ ):
logging.info(
- 'Downloading imagenet train dataset from {}'.format(imagenet_train_url))
+ 'Downloading imagenet train dataset from {}'.format(imagenet_train_url)
+ )
_download_url(url=imagenet_train_url, data_dir=data_dir)
# Download imagenet val dataset
if not os.path.exists(imagenet_val_filepath) and not os.path.exists(
- imagenet_val_download_filepath):
- logging.info('Downloading imagenet validation dataset from {}'.format(
- imagenet_val_url))
+ imagenet_val_download_filepath
+ ):
+ logging.info(
+ 'Downloading imagenet validation dataset from {}'.format(imagenet_val_url)
+ )
_download_url(url=imagenet_val_url, data_dir=data_dir)
# Download imagenet test set
@@ -501,31 +537,40 @@ def setup_imagenet_jax(data_dir):
# Setup jax dataset dir
imagenet_jax_data_dir = os.path.join(data_dir, 'jax')
- manual_download_dir = os.path.join(imagenet_jax_data_dir,
- 'downloads',
- 'manual')
+ manual_download_dir = os.path.join(
+ imagenet_jax_data_dir, 'downloads', 'manual'
+ )
os.makedirs(manual_download_dir, exist_ok=True)
# Copy tar file into jax/downloads/manual
logging.info('Checking if tar files already exists in jax/downloads/manual.')
if not os.path.exists(
- os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)):
- logging.info('Moving {} to {}'.format(train_tar_file_path,
- manual_download_dir))
+ os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)
+ ):
+ logging.info(
+ 'Moving {} to {}'.format(train_tar_file_path, manual_download_dir)
+ )
shutil.move(train_tar_file_path, manual_download_dir)
if not os.path.exists(
- os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)):
- logging.info('Moving {} to {}'.format(val_tar_file_path,
- manual_download_dir))
+ os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)
+ ):
+ logging.info(
+ 'Moving {} to {}'.format(val_tar_file_path, manual_download_dir)
+ )
shutil.move(val_tar_file_path, manual_download_dir)
if not os.path.exists(os.path.join(imagenet_jax_data_dir, 'imagenet_v2')):
- logging.info('Moving imagenet_v2 to {}'.format(
- os.path.join(imagenet_jax_data_dir, 'imagenet_v2')))
- shutil.move(test_dir_path,
- os.path.join(imagenet_jax_data_dir, 'imagenet_v2'))
+ logging.info(
+ 'Moving imagenet_v2 to {}'.format(
+ os.path.join(imagenet_jax_data_dir, 'imagenet_v2')
+ )
+ )
+ shutil.move(
+ test_dir_path, os.path.join(imagenet_jax_data_dir, 'imagenet_v2')
+ )
logging.info('Preparing imagenet data.')
ds_builder = tfds.builder(
- 'imagenet2012:5.1.0', data_dir=os.path.join(imagenet_jax_data_dir))
+ 'imagenet2012:5.1.0', data_dir=os.path.join(imagenet_jax_data_dir)
+ )
ds_builder.download_and_prepare()
logging.info('Set up imagenet dataset for jax framework complete')
@@ -539,14 +584,18 @@ def setup_imagenet_pytorch(data_dir):
manual_download_dir = os.path.join(data_dir, 'jax', 'downloads', 'manual')
if not os.path.exists(train_tar_file_path):
if os.path.exists(
- os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)):
- train_tar_file_path = os.path.join(manual_download_dir,
- IMAGENET_TRAIN_TAR_FILENAME)
+ os.path.join(manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME)
+ ):
+ train_tar_file_path = os.path.join(
+ manual_download_dir, IMAGENET_TRAIN_TAR_FILENAME
+ )
if not os.path.exists(val_tar_file_path):
if os.path.exists(
- os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)):
- val_tar_file_path = os.path.join(manual_download_dir,
- IMAGENET_VAL_TAR_FILENAME)
+ os.path.join(manual_download_dir, IMAGENET_VAL_TAR_FILENAME)
+ ):
+ val_tar_file_path = os.path.join(
+ manual_download_dir, IMAGENET_VAL_TAR_FILENAME
+ )
# Setup pytorch dataset dir
imagenet_pytorch_data_dir = os.path.join(data_dir, 'pytorch')
@@ -557,56 +606,68 @@ def setup_imagenet_pytorch(data_dir):
# Move tar files and imagenet_v2 into pytorch directory
if not os.path.exists(
- os.path.join(imagenet_pytorch_data_dir, IMAGENET_TRAIN_TAR_FILENAME)):
- logging.info('Moving {} to {}'.format(train_tar_file_path,
- imagenet_pytorch_data_dir))
+ os.path.join(imagenet_pytorch_data_dir, IMAGENET_TRAIN_TAR_FILENAME)
+ ):
+ logging.info(
+ 'Moving {} to {}'.format(train_tar_file_path, imagenet_pytorch_data_dir)
+ )
shutil.move(train_tar_file_path, imagenet_pytorch_data_dir)
if not os.path.exists(
- os.path.join(imagenet_pytorch_data_dir, IMAGENET_VAL_TAR_FILENAME)):
- logging.info('Moving {} to {}'.format(val_tar_file_path,
- imagenet_pytorch_data_dir))
+ os.path.join(imagenet_pytorch_data_dir, IMAGENET_VAL_TAR_FILENAME)
+ ):
+ logging.info(
+ 'Moving {} to {}'.format(val_tar_file_path, imagenet_pytorch_data_dir)
+ )
shutil.move(val_tar_file_path, imagenet_pytorch_data_dir)
if not os.path.exists(os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2')):
- logging.info('Moving imagenet_v2 to {}'.format(
- os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2')))
- shutil.move(test_dir_path,
- os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2'))
+ logging.info(
+ 'Moving imagenet_v2 to {}'.format(
+ os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2')
+ )
+ )
+ shutil.move(
+ test_dir_path, os.path.join(imagenet_pytorch_data_dir, 'imagenet_v2')
+ )
# Extract train data\
logging.info('Extracting imagenet train data')
extract(
- os.path.join(imagenet_pytorch_data_dir, IMAGENET_TRAIN_TAR_FILENAME),
- os.path.join(imagenet_pytorch_data_dir, 'train'),
- mode='r:')
+ os.path.join(imagenet_pytorch_data_dir, IMAGENET_TRAIN_TAR_FILENAME),
+ os.path.join(imagenet_pytorch_data_dir, 'train'),
+ mode='r:',
+ )
train_tar_filenames = os.listdir(
- os.path.join(imagenet_pytorch_data_dir, 'train'))
+ os.path.join(imagenet_pytorch_data_dir, 'train')
+ )
for tar_filename in train_tar_filenames:
if tar_filename.endswith('.tar'):
dir_name = tar_filename[:-4]
extract(
- os.path.join(imagenet_pytorch_data_dir, 'train', tar_filename),
- os.path.join(imagenet_pytorch_data_dir, 'train', dir_name),
- mode='r:')
+ os.path.join(imagenet_pytorch_data_dir, 'train', tar_filename),
+ os.path.join(imagenet_pytorch_data_dir, 'train', dir_name),
+ mode='r:',
+ )
# Extract val data
logging.info('Extracting imagenet val data')
extract(
- os.path.join(imagenet_pytorch_data_dir, IMAGENET_VAL_TAR_FILENAME),
- os.path.join(imagenet_pytorch_data_dir, 'val'),
- mode='r:')
+ os.path.join(imagenet_pytorch_data_dir, IMAGENET_VAL_TAR_FILENAME),
+ os.path.join(imagenet_pytorch_data_dir, 'val'),
+ mode='r:',
+ )
valprep_command = [
- 'wget',
- '-qO-',
- 'https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh'
+ 'wget',
+ '-qO-',
+ 'https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh',
]
valprep_download = subprocess.Popen(valprep_command, stdout=subprocess.PIPE)
- valprep_process = subprocess.Popen(['bash'],
- stdin=valprep_download.stdout,
- cwd=os.path.expanduser(
- os.path.join(imagenet_pytorch_data_dir,
- 'val')))
+ valprep_process = subprocess.Popen(
+ ['bash'],
+ stdin=valprep_download.stdout,
+ cwd=os.path.expanduser(os.path.join(imagenet_pytorch_data_dir, 'val')),
+ )
valprep_download.stdout.close()
valprep_process.communicate()
logging.info('Set up imagenet dataset for pytorch framework complete')
@@ -614,8 +675,8 @@ def setup_imagenet_pytorch(data_dir):
def download_imagenet_v2(data_dir):
tfds.builder(
- 'imagenet_v2/matched-frequency:3.0.0',
- data_dir=data_dir).download_and_prepare()
+ 'imagenet_v2/matched-frequency:3.0.0', data_dir=data_dir
+ ).download_and_prepare()
def download_librispeech(data_dir, tmp_dir):
@@ -634,41 +695,46 @@ def download_librispeech(data_dir, tmp_dir):
if split == 'test' and version == 'other':
continue
wget_cmd = (
- f'wget --directory-prefix={tmp_librispeech_dir} '
- f'http://www.openslr.org/resources/12/{split}-{version}.tar.gz')
+ f'wget --directory-prefix={tmp_librispeech_dir} '
+ f'http://www.openslr.org/resources/12/{split}-{version}.tar.gz'
+ )
subprocess.Popen(wget_cmd, shell=True).communicate()
tar_path = os.path.join(tmp_librispeech_dir, f'{split}-{version}.tar.gz')
subprocess.Popen(
- f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}',
- shell=True).communicate()
+ f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}', shell=True
+ ).communicate()
tars = [
- 'raw-metadata.tar.gz',
- 'train-clean-100.tar.gz',
- 'train-clean-360.tar.gz',
- 'train-other-500.tar.gz',
+ 'raw-metadata.tar.gz',
+ 'train-clean-100.tar.gz',
+ 'train-clean-360.tar.gz',
+ 'train-other-500.tar.gz',
]
for tar_filename in tars:
- wget_cmd = (f'wget --directory-prefix={tmp_librispeech_dir} '
- f'http://www.openslr.org/resources/12/{tar_filename}')
+ wget_cmd = (
+ f'wget --directory-prefix={tmp_librispeech_dir} '
+ f'http://www.openslr.org/resources/12/{tar_filename}'
+ )
subprocess.Popen(wget_cmd, shell=True).communicate()
tar_path = os.path.join(tmp_librispeech_dir, tar_filename)
subprocess.Popen(
- f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}',
- shell=True).communicate()
+ f'tar xzvf {tar_path} --directory {tmp_librispeech_dir}', shell=True
+ ).communicate()
tokenizer_vocab_path = os.path.join(final_data_dir, 'spm_model.vocab')
if not os.path.exists(tokenizer_vocab_path):
librispeech_tokenizer.run(
- train=True,
- input_dir=extracted_data_dir,
- tokenizer_vocab_path=tokenizer_vocab_path)
+ train=True,
+ input_dir=extracted_data_dir,
+ tokenizer_vocab_path=tokenizer_vocab_path,
+ )
librispeech_preprocess.run(
- input_dir=extracted_data_dir,
- output_dir=final_data_dir,
- tokenizer_vocab_path=tokenizer_vocab_path)
+ input_dir=extracted_data_dir,
+ output_dir=final_data_dir,
+ tokenizer_vocab_path=tokenizer_vocab_path,
+ )
def download_mnist(data_dir):
@@ -691,12 +757,14 @@ def download_wmt(data_dir):
if ds_name == 'wmt17_translate/de-en:1.0.0':
ds = dataset_builder.as_dataset(split='train', shuffle_files=False)
ds = ds.map(
- functools.partial(normalize_feature_names, dataset_builder.info),
- num_parallel_calls=tf.data.AUTOTUNE)
+ functools.partial(normalize_feature_names, dataset_builder.info),
+ num_parallel_calls=tf.data.AUTOTUNE,
+ )
# Tokenize data.
vocab_path = os.path.join(data_dir, 'wmt_sentencepiece_model')
tokenizer.train_tokenizer(
- ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7)
+ ds, vocab_path=vocab_path, vocab_size=32000, max_corpus_chars=10**7
+ )
def main(_):
@@ -715,10 +783,9 @@ def main(_):
if FLAGS.all or FLAGS.criteo1tb:
logging.info('Downloading criteo1tb...')
- download_criteo1tb(data_dir,
- tmp_dir,
- num_decompression_threads,
- FLAGS.interactive_deletion)
+ download_criteo1tb(
+ data_dir, tmp_dir, num_decompression_threads, FLAGS.interactive_deletion
+ )
if FLAGS.all or FLAGS.mnist:
logging.info('Downloading MNIST...')
@@ -730,19 +797,24 @@ def main(_):
knee_singlecoil_train_url = FLAGS.fastmri_knee_singlecoil_train_url
knee_singlecoil_val_url = FLAGS.fastmri_knee_singlecoil_val_url
knee_singlecoil_test_url = FLAGS.fastmri_knee_singlecoil_test_url
- if None in (knee_singlecoil_train_url,
- knee_singlecoil_val_url,
- knee_singlecoil_test_url):
+ if None in (
+ knee_singlecoil_train_url,
+ knee_singlecoil_val_url,
+ knee_singlecoil_test_url,
+ ):
raise ValueError(
- 'Must provide three --fastmri_knee_singlecoil_[train,val,test]_url '
- 'to download the FastMRI dataset.\nSign up for the URLs at '
- 'https://fastmri.med.nyu.edu/.')
+ 'Must provide three --fastmri_knee_singlecoil_[train,val,test]_url '
+ 'to download the FastMRI dataset.\nSign up for the URLs at '
+ 'https://fastmri.med.nyu.edu/.'
+ )
if not FLAGS.skip_download:
- download_fastmri(data_dir,
- knee_singlecoil_train_url,
- knee_singlecoil_val_url,
- knee_singlecoil_test_url)
+ download_fastmri(
+ data_dir,
+ knee_singlecoil_train_url,
+ knee_singlecoil_val_url,
+ knee_singlecoil_test_url,
+ )
logging.info('fastMRI download completed. Extracting...')
setup_fastmri(data_dir)
@@ -754,12 +826,13 @@ def main(_):
imagenet_val_url = FLAGS.imagenet_val_url
if imagenet_train_url is None or imagenet_val_url is None:
raise ValueError(
- 'Must provide both --imagenet_{train,val}_url to download the '
- 'ImageNet dataset. Sign up for the URLs at https://image-net.org/.')
+ 'Must provide both --imagenet_{train,val}_url to download the '
+ 'ImageNet dataset. Sign up for the URLs at https://image-net.org/.'
+ )
if FLAGS.framework is None:
raise ValueError(
- 'Please specify either jax or pytorch framework through framework '
- 'flag.')
+ 'Please specify either jax or pytorch framework through framework flag.'
+ )
if not FLAGS.skip_download:
logging.info('Downloading ImageNet...')
download_imagenet(data_dir, imagenet_train_url, imagenet_val_url)
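
main() guards the FastMRI download with a single membership test: if any of the three knee_singlecoil URLs is unset, it raises before touching the network. A minimal sketch of that check, with placeholder URLs in place of the real flag values:

knee_train_url = 'https://example.org/knee_singlecoil_train.tar.xz'  # placeholders
knee_val_url = 'https://example.org/knee_singlecoil_val.tar.xz'
knee_test_url = 'https://example.org/knee_singlecoil_test.tar.xz'

# True as soon as any of the three flags was left at its None default.
if None in (knee_train_url, knee_val_url, knee_test_url):
  raise ValueError(
    'Must provide three --fastmri_knee_singlecoil_[train,val,test]_url '
    'to download the FastMRI dataset.'
  )
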
diff --git a/datasets/librispeech_preprocess.py b/datasets/librispeech_preprocess.py
index a8c5cae1d..c419eb39b 100644
--- a/datasets/librispeech_preprocess.py
+++ b/datasets/librispeech_preprocess.py
@@ -28,17 +28,18 @@
# taken from TFDS page for librispeech dataset :
# https://www.tensorflow.org/datasets/catalog/librispeech
librispeech_example_counts = {
- 'train-clean-100': 28539,
- 'train-clean-360': 104014,
- 'train-other-500': 148688,
- 'test-clean': 2620, # 'test-other': 2939,
- 'dev-clean': 2703,
- 'dev-other': 2864,
+ 'train-clean-100': 28539,
+ 'train-clean-360': 104014,
+ 'train-other-500': 148688,
+ 'test-clean': 2620, # 'test-other': 2939,
+ 'dev-clean': 2703,
+ 'dev-other': 2864,
}
class Counter:
"""A threadsafe counter."""
+
lock = threading.Lock()
value = 0
@@ -56,10 +57,12 @@ def report_progress(count, total, start_time):
now = time.time()
size = 50
filled = int(round(size * count / float(total)))
- percent = round(100. * count / float(total), 1)
- bar = "-" * filled + "." * (size - filled)
- sys.stdout.write("[%s] %d%% (%d of %d) %.2f sample/sec\r" %
- (bar, percent, count, total, count / (now - start_time)))
+ percent = round(100.0 * count / float(total), 1)
+ bar = '-' * filled + '.' * (size - filled)
+ sys.stdout.write(
+ '[%s] %d%% (%d of %d) %.2f sample/sec\r'
+ % (bar, percent, count, total, count / (now - start_time))
+ )
sys.stdout.flush()
@@ -72,8 +75,10 @@ def process(index):
data_folder, speaker_folder, chapter_folder = index
utterance_ids = []
- trans_file = (f'{data_folder}/{speaker_folder}/{chapter_folder}/'
- f'{speaker_folder}-{chapter_folder}.trans.txt')
+ trans_file = (
+ f'{data_folder}/{speaker_folder}/{chapter_folder}/'
+ f'{speaker_folder}-{chapter_folder}.trans.txt'
+ )
if not exists(trans_file):
skipped.inc()
return utterance_ids
@@ -82,7 +87,8 @@ def process(index):
for l in f:
utt, trans = l.strip().split(' ', maxsplit=1)
audio_path = (
- f'{data_folder}/{speaker_folder}/{chapter_folder}/{utt}.flac')
+ f'{data_folder}/{speaker_folder}/{chapter_folder}/{utt}.flac'
+ )
if not os.path.isfile(audio_path):
skipped.inc()
@@ -105,9 +111,11 @@ def process(index):
np.save('{}/{}/{}_targets.npy'.format(out_folder, split, utt), targets)
finished.inc()
- report_progress(finished.val() + skipped.val(),
- librispeech_example_counts[split],
- start_time)
+ report_progress(
+ finished.val() + skipped.val(),
+ librispeech_example_counts[split],
+ start_time,
+ )
utterance_ids.append(utt)
return utterance_ids
@@ -126,10 +134,12 @@ def process(index):
end_time = time.time()
elapsed_time = end_time - start_time
- print(' \n time taken to preprocess split : ',
- split,
- ' = ',
- time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))
+ print(
+ ' \n time taken to preprocess split : ',
+ split,
+ ' = ',
+ time.strftime('%H:%M:%S', time.gmtime(elapsed_time)),
+ )
final_count = finished.val() + skipped.val()
return pd.DataFrame(file_trans, columns=['id']), final_count
@@ -147,12 +157,12 @@ def run(input_dir, output_dir, tokenizer_vocab_path):
os.makedirs(output_dir, exist_ok=True)
subset_list = [
- 'train-clean-100',
- 'train-clean-360',
- 'train-other-500',
- 'dev-clean',
- 'dev-other',
- 'test-clean', # 'test-other',
+ 'train-clean-100',
+ 'train-clean-360',
+ 'train-other-500',
+ 'dev-clean',
+ 'dev-other',
+ 'test-clean', # 'test-other',
]
for subset in subset_list:
logging.info('Processing split = %s...', subset)
@@ -160,10 +170,14 @@ def run(input_dir, output_dir, tokenizer_vocab_path):
out_dir = os.path.join(output_dir, subset)
os.makedirs(out_dir, exist_ok=True)
example_ids, num_entries = preprocess_data(
- in_dir, output_dir, tokenizer, subset)
+ in_dir, output_dir, tokenizer, subset
+ )
if num_entries != librispeech_example_counts[subset]:
- raise ValueError('Preprocessed dataframe final count not equal to '
- 'expected count: {} vs expected {}'.format(
- num_entries, librispeech_example_counts[subset]))
+ raise ValueError(
+ 'Preprocessed dataframe final count not equal to '
+ 'expected count: {} vs expected {}'.format(
+ num_entries, librispeech_example_counts[subset]
+ )
+ )
example_ids.to_csv(os.path.join(output_dir, f'{subset}.csv'))
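
The preprocessing loop above tallies finished and skipped utterances through a shared Counter ("A threadsafe counter.") whose inc()/val() methods are used but not shown in these hunks. A hedged sketch of how such a lock-guarded counter can look; the method bodies, and the use of instance rather than class attributes, are assumptions and not the file's actual code:

import threading

class Counter:
  """A threadsafe counter (sketch; inc/val bodies are assumed)."""

  def __init__(self):
    self.lock = threading.Lock()
    self.value = 0

  def inc(self):
    with self.lock:
      self.value += 1

  def val(self):
    with self.lock:
      return self.value

finished = Counter()
finished.inc()
assert finished.val() == 1
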
diff --git a/datasets/librispeech_tokenizer.py b/datasets/librispeech_tokenizer.py
index 2f559752a..5b9888cc2 100644
--- a/datasets/librispeech_tokenizer.py
+++ b/datasets/librispeech_tokenizer.py
@@ -24,7 +24,8 @@
def dump_chars_for_training(data_folder, splits, maxchars: int = int(1e7)):
char_count = 0
with tempfile.NamedTemporaryFile(
- delete=False, prefix='/tmp/ds_chars') as outfp:
+ delete=False, prefix='/tmp/ds_chars'
+ ) as outfp:
for split in splits:
data_folder = data_folder + '/' + split
for _, speaker_folder in enumerate(os.listdir(data_folder)):
@@ -32,8 +33,10 @@ def dump_chars_for_training(data_folder, splits, maxchars: int = int(1e7)):
break
for chapter_folder in os.listdir(f'{data_folder}/{speaker_folder}'):
- trans_file = (f'{data_folder}/{speaker_folder}/{chapter_folder}/'
- f'{speaker_folder}-{chapter_folder}.trans.txt')
+ trans_file = (
+ f'{data_folder}/{speaker_folder}/{chapter_folder}/'
+ f'{speaker_folder}-{chapter_folder}.trans.txt'
+ )
if not exists(trans_file):
logging.info('path does not exist -> %s', trans_file)
continue
@@ -50,13 +53,15 @@ def dump_chars_for_training(data_folder, splits, maxchars: int = int(1e7)):
return outfp
-def train_tokenizer(data_dir: str,
- splits,
- vocab_size: int = 1024,
- model_path: str = 'spm_model.vocab',
- maxchars: int = int(1e7),
- model_type: str = 'unigram',
- character_coverage: float = 1.0):
+def train_tokenizer(
+ data_dir: str,
+ splits,
+ vocab_size: int = 1024,
+ model_path: str = 'spm_model.vocab',
+ maxchars: int = int(1e7),
+ model_type: str = 'unigram',
+ character_coverage: float = 1.0,
+):
"""Train SentencePiece tokenizer from subset of tf dataset.
Args:
@@ -77,15 +82,18 @@ def train_tokenizer(data_dir: str,
charfile = dump_chars_for_training(data_dir, splits, maxchars=maxchars)
with tempfile.NamedTemporaryFile(
- delete=False, prefix='/tmp/sp_tmp') as model_fp:
+ delete=False, prefix='/tmp/sp_tmp'
+ ) as model_fp:
pass # we just want a prefix'd tmp-filename
- argstr = ' '.join([
+ argstr = ' '.join(
+ [
f'--input={charfile.name}',
f'--vocab_size={vocab_size}',
f'--character_coverage={character_coverage}',
f'--model_prefix={model_fp.name}',
f'--model_type={model_type}',
- ])
+ ]
+ )
spm.SentencePieceTrainer.Train(argstr)
copy_rename_path = abs_model_path + '.rntmp'
@@ -104,7 +112,8 @@ def load_tokenizer(model_filepath):
with gfile.GFile(model_filepath, 'rb') as model_fp:
sp_model = model_fp.read()
sp_tokenizer = tftxt.SentencepieceTokenizer(
- model=sp_model, add_bos=False, add_eos=True, reverse=False)
+ model=sp_model, add_bos=False, add_eos=True, reverse=False
+ )
return sp_tokenizer
@@ -123,8 +132,9 @@ def run(train, input_dir, tokenizer_vocab_path):
detokenized = tokenizer.detokenize(tokens).numpy().decode('utf-8')
logging.info('Original input = %s', test_input)
- logging.info('Output after after tokenizing and detokenizing = %s',
- detokenized)
+ logging.info(
+    'Output after tokenizing and detokenizing = %s', detokenized
+      'Output after tokenizing and detokenizing = %s', detokenized
+ )
if detokenized == test_input:
logging.info('Tokenizer working correctly!')
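
load_tokenizer above rebuilds the SentencePiece tokenizer from the saved vocab file, and run() sanity-checks it with a tokenize/detokenize round trip. A hedged end-to-end sketch of that check; the vocab path and the sample sentence are placeholders, not values taken from the script:

import tensorflow_text as tftxt
from tensorflow.io import gfile

model_filepath = '/data/librispeech/spm_model.vocab'  # placeholder path

with gfile.GFile(model_filepath, 'rb') as model_fp:
  sp_model = model_fp.read()
tokenizer = tftxt.SentencepieceTokenizer(
  model=sp_model, add_bos=False, add_eos=True, reverse=False
)

# Mirrors run(): detokenize(tokenize(x)) should reproduce the input string.
test_input = 'OPEN SOURCE ROCKS'
tokens = tokenizer.tokenize(test_input)
detokenized = tokenizer.detokenize(tokens).numpy().decode('utf-8')
print(detokenized == test_input)
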
From c34af17f6e272be88e914ef297fdea65c62940d5 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 10:21:09 +0200
Subject: [PATCH 107/123] Format algoperf/
---
algoperf/__init__.py | 2 +-
algoperf/checkpoint_utils.py | 195 +++--
algoperf/data_utils.py | 72 +-
algoperf/halton.py | 154 ++--
algoperf/init_utils.py | 4 +-
algoperf/interop_utils.py | 3 +-
algoperf/logger_utils.py | 149 ++--
algoperf/param_utils.py | 28 +-
algoperf/profiler.py | 81 +-
algoperf/pytorch_utils.py | 21 +-
algoperf/random_utils.py | 10 +-
algoperf/spec.py | 322 +++----
.../cifar/cifar_jax/input_pipeline.py | 143 ++--
algoperf/workloads/cifar/cifar_jax/models.py | 52 +-
.../workloads/cifar/cifar_jax/workload.py | 189 ++--
.../workloads/cifar/cifar_pytorch/models.py | 109 ++-
.../workloads/cifar/cifar_pytorch/workload.py | 165 ++--
algoperf/workloads/cifar/workload.py | 106 +--
.../criteo1tb/criteo1tb_jax/models.py | 133 +--
.../criteo1tb/criteo1tb_jax/workload.py | 120 +--
.../criteo1tb/criteo1tb_pytorch/models.py | 144 ++--
.../criteo1tb/criteo1tb_pytorch/workload.py | 174 ++--
.../workloads/criteo1tb/input_pipeline.py | 97 ++-
algoperf/workloads/criteo1tb/workload.py | 75 +-
.../workloads/fastmri/fastmri_jax/models.py | 80 +-
.../workloads/fastmri/fastmri_jax/ssim.py | 24 +-
.../workloads/fastmri/fastmri_jax/workload.py | 160 ++--
.../fastmri/fastmri_pytorch/models.py | 99 +--
.../workloads/fastmri/fastmri_pytorch/ssim.py | 21 +-
.../fastmri/fastmri_pytorch/workload.py | 206 ++---
algoperf/workloads/fastmri/input_pipeline.py | 120 +--
algoperf/workloads/fastmri/workload.py | 38 +-
.../imagenet_jax/custom_tf_addons.py | 548 ++++++------
.../imagenet_jax/input_pipeline.py | 358 ++++----
.../imagenet_resnet/imagenet_jax/models.py | 71 +-
.../imagenet_jax/randaugment.py | 228 +++--
.../imagenet_resnet/imagenet_jax/workload.py | 267 +++---
.../imagenet_pytorch/models.py | 219 ++---
.../imagenet_pytorch/randaugment.py | 124 +--
.../imagenet_pytorch/workload.py | 230 ++---
.../workloads/imagenet_resnet/imagenet_v2.py | 38 +-
.../workloads/imagenet_resnet/workload.py | 56 +-
.../imagenet_vit/imagenet_jax/models.py | 153 ++--
.../imagenet_vit/imagenet_jax/workload.py | 113 +--
.../imagenet_vit/imagenet_pytorch/models.py | 157 ++--
.../imagenet_vit/imagenet_pytorch/workload.py | 55 +-
algoperf/workloads/imagenet_vit/workload.py | 120 +--
.../librispeech_conformer/input_pipeline.py | 8 +-
.../librispeech_preprocessor.py | 512 ++++++-----
.../librispeech_jax/models.py | 484 ++++++-----
.../librispeech_jax/spectrum_augmenter.py | 81 +-
.../librispeech_jax/workload.py | 297 ++++---
.../librispeech_pytorch/models.py | 230 ++---
.../librispeech_pytorch/preprocessor.py | 642 +++++++-------
.../librispeech_pytorch/spectrum_augmenter.py | 97 ++-
.../librispeech_pytorch/workload.py | 239 +++---
.../librispeech_conformer/metrics.py | 39 +-
.../librispeech_conformer/workload.py | 11 +-
.../librispeech_jax/models.py | 309 ++++---
.../librispeech_jax/workload.py | 88 +-
.../librispeech_pytorch/models.py | 177 ++--
.../librispeech_pytorch/workload.py | 80 +-
.../workloads/mnist/mnist_jax/workload.py | 98 ++-
.../workloads/mnist/mnist_pytorch/workload.py | 163 ++--
algoperf/workloads/mnist/workload.py | 152 ++--
algoperf/workloads/ogbg/input_pipeline.py | 59 +-
algoperf/workloads/ogbg/metrics.py | 13 +-
algoperf/workloads/ogbg/ogbg_jax/models.py | 56 +-
algoperf/workloads/ogbg/ogbg_jax/workload.py | 108 +--
.../workloads/ogbg/ogbg_pytorch/models.py | 146 ++--
.../workloads/ogbg/ogbg_pytorch/workload.py | 184 ++--
algoperf/workloads/ogbg/workload.py | 110 +--
algoperf/workloads/utils.py | 12 +-
algoperf/workloads/wmt/bleu.py | 341 ++++----
algoperf/workloads/wmt/input_pipeline.py | 162 ++--
algoperf/workloads/wmt/tokenizer.py | 78 +-
algoperf/workloads/wmt/wmt_jax/decode.py | 160 ++--
algoperf/workloads/wmt/wmt_jax/models.py | 489 ++++++-----
algoperf/workloads/wmt/wmt_jax/workload.py | 246 +++---
algoperf/workloads/wmt/wmt_pytorch/decode.py | 210 +++--
algoperf/workloads/wmt/wmt_pytorch/models.py | 808 ++++++++++--------
.../workloads/wmt/wmt_pytorch/workload.py | 270 +++---
algoperf/workloads/wmt/workload.py | 116 +--
algoperf/workloads/workloads.py | 303 +++----
84 files changed, 7324 insertions(+), 6287 deletions(-)
diff --git a/algoperf/__init__.py b/algoperf/__init__.py
index 7d54f8290..5ecee05af 100644
--- a/algoperf/__init__.py
+++ b/algoperf/__init__.py
@@ -2,4 +2,4 @@
from ._version import version as __version__
-__all__ = ["__version__"]
+__all__ = ['__version__']
diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py
index f4cb6c2db..577baaa34 100644
--- a/algoperf/checkpoint_utils.py
+++ b/algoperf/checkpoint_utils.py
@@ -20,24 +20,28 @@
from algoperf.pytorch_utils import pytorch_setup
_, _, DEVICE, _ = pytorch_setup()
-CheckpointReturn = Tuple[spec.OptimizerState,
- spec.ParameterContainer,
- spec.ModelAuxiliaryState,
- dict,
- list,
- int,
- int]
-
-
-def maybe_restore_checkpoint(framework: str,
- optimizer_state: spec.OptimizerState,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- train_state: dict,
- eval_results: list,
- global_step: int,
- preemption_count: int,
- checkpoint_dir: str) -> CheckpointReturn:
+CheckpointReturn = Tuple[
+ spec.OptimizerState,
+ spec.ParameterContainer,
+ spec.ModelAuxiliaryState,
+ dict,
+ list,
+ int,
+ int,
+]
+
+
+def maybe_restore_checkpoint(
+ framework: str,
+ optimizer_state: spec.OptimizerState,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ train_state: dict,
+ eval_results: list,
+ global_step: int,
+ preemption_count: int,
+ checkpoint_dir: str,
+) -> CheckpointReturn:
"""Optionally restores from a checkpoint.
The checkpoint logic is as follows: if there is a checkpoint in
@@ -69,20 +73,22 @@ def maybe_restore_checkpoint(framework: str,
uninitialized_global_step = -1
uninitialized_preemption_count = -1
checkpoint_state = {
- 'model_params': model_params,
- 'optimizer_state': opt_state,
- 'model_state': model_state,
- 'train_state': train_state,
- 'eval_results': None,
- 'global_step': uninitialized_global_step,
- 'preemption_count': uninitialized_preemption_count,
+ 'model_params': model_params,
+ 'optimizer_state': opt_state,
+ 'model_state': model_state,
+ 'train_state': train_state,
+ 'eval_results': None,
+ 'global_step': uninitialized_global_step,
+ 'preemption_count': uninitialized_preemption_count,
}
if framework == 'jax':
latest_ckpt = flax_checkpoints.restore_checkpoint(
- checkpoint_dir, target=checkpoint_state)
- save_path = os.path.join(checkpoint_dir,
- 'checkpoint_' + str(latest_ckpt['global_step']))
+ checkpoint_dir, target=checkpoint_state
+ )
+ save_path = os.path.join(
+ checkpoint_dir, 'checkpoint_' + str(latest_ckpt['global_step'])
+ )
else:
latest_ckpt = checkpoint_state
save_path = latest_checkpoint(checkpoint_dir)
@@ -94,55 +100,64 @@ def maybe_restore_checkpoint(framework: str,
found_checkpoint = latest_ckpt['global_step'] != uninitialized_global_step
if not found_checkpoint:
- return (optimizer_state,
- model_params,
- model_state,
- train_state,
- eval_results,
- global_step,
- preemption_count)
+ return (
+ optimizer_state,
+ model_params,
+ model_state,
+ train_state,
+ eval_results,
+ global_step,
+ preemption_count,
+ )
# If there's the latest checkpoint in the checkpoint_dir, restore from that.
if framework == 'jax':
checkpoint_state = replicate_checkpoint(
- latest_ckpt,
- pytree_keys=[
- 'optimizer_state',
- 'model_params',
- 'model_state',
- ])
- checkpoint_state['optimizer_state'] = (checkpoint_state['optimizer_state'],
- opt_update_fn)
+ latest_ckpt,
+ pytree_keys=[
+ 'optimizer_state',
+ 'model_params',
+ 'model_state',
+ ],
+ )
+ checkpoint_state['optimizer_state'] = (
+ checkpoint_state['optimizer_state'],
+ opt_update_fn,
+ )
checkpoint_state['eval_results'] = [
- (value, key) for key, value in latest_ckpt['eval_results'].items()
+ (value, key) for key, value in latest_ckpt['eval_results'].items()
]
else:
checkpoint_state = latest_ckpt
if isinstance(
- model_params,
- (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
+ model_params,
+ (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel),
+ ):
model_params = model_params.module
model_params.load_state_dict(checkpoint_state['model_params'])
checkpoint_state['model_params'] = model_params
for key in optimizer_state.keys():
optimizer_state[key].load_state_dict(
- checkpoint_state['optimizer_state'][key])
+ checkpoint_state['optimizer_state'][key]
+ )
checkpoint_state['optimizer_state'][key] = optimizer_state[key]
logging.info(f'Loaded checkpoint from {save_path}.')
- return (checkpoint_state['optimizer_state'],
- checkpoint_state['model_params'],
- checkpoint_state['model_state'],
- checkpoint_state['train_state'],
- list(checkpoint_state['eval_results']),
- checkpoint_state['global_step'],
- checkpoint_state['preemption_count'] + 1)
-
-
-def replicate_checkpoint(latest: dict,
- pytree_keys: Sequence[str],
- replicate: bool = True) -> dict:
+ return (
+ checkpoint_state['optimizer_state'],
+ checkpoint_state['model_params'],
+ checkpoint_state['model_state'],
+ checkpoint_state['train_state'],
+ list(checkpoint_state['eval_results']),
+ checkpoint_state['global_step'],
+ checkpoint_state['preemption_count'] + 1,
+ )
+
+
+def replicate_checkpoint(
+ latest: dict, pytree_keys: Sequence[str], replicate: bool = True
+) -> dict:
"""Restores from the provided checkpoint.
Args:
@@ -163,16 +178,18 @@ def replicate_checkpoint(latest: dict,
return pytree
-def save_checkpoint(framework: str,
- optimizer_state: spec.OptimizerState,
- model_params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- train_state: dict,
- eval_results: list,
- global_step: int,
- preemption_count: int,
- checkpoint_dir: str,
- save_intermediate_checkpoints: bool) -> None:
+def save_checkpoint(
+ framework: str,
+ optimizer_state: spec.OptimizerState,
+ model_params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ train_state: dict,
+ eval_results: list,
+ global_step: int,
+ preemption_count: int,
+ checkpoint_dir: str,
+ save_intermediate_checkpoints: bool,
+) -> None:
"""Save the checkpoint in `checkpoint_dir`.
Args:
@@ -199,8 +216,9 @@ def save_checkpoint(framework: str,
model_state = jax.device_get(jax_utils.unreplicate(model_state))
else:
if isinstance(
- model_params,
- (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
+ model_params,
+ (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel),
+ ):
model_params = model_params.module
model_params = model_params.state_dict()
optimizer_state_dict = {}
@@ -209,33 +227,36 @@ def save_checkpoint(framework: str,
optimizer_state_dict[key] = optimizer_state[key].state_dict()
else:
logging.warning(
- f'The optimizer state for key {key} is not saved, because '
- f'{type(optimizer_state[key])} has not implemented a state_dict() '
- 'method.')
+ f'The optimizer state for key {key} is not saved, because '
+ f'{type(optimizer_state[key])} has not implemented a state_dict() '
+ 'method.'
+ )
opt_state = optimizer_state_dict
checkpoint_state = {
- 'model_params': model_params,
- 'optimizer_state': opt_state,
- 'model_state': model_state,
- 'train_state': train_state,
- 'eval_results': tuple(eval_results),
- 'global_step': global_step,
- 'preemption_count': preemption_count,
+ 'model_params': model_params,
+ 'optimizer_state': opt_state,
+ 'model_state': model_state,
+ 'train_state': train_state,
+ 'eval_results': tuple(eval_results),
+ 'global_step': global_step,
+ 'preemption_count': preemption_count,
}
save_path = os.path.join(checkpoint_dir, f'checkpoint_{global_step}')
if framework == 'jax':
flax_checkpoints.save_checkpoint(
- checkpoint_dir,
- target=checkpoint_state,
- step=global_step,
- overwrite=True,
- keep=np.inf if save_intermediate_checkpoints else 1)
+ checkpoint_dir,
+ target=checkpoint_state,
+ step=global_step,
+ overwrite=True,
+ keep=np.inf if save_intermediate_checkpoints else 1,
+ )
else:
if not save_intermediate_checkpoints:
checkpoint_files = gfile.glob(
- os.path.join(checkpoint_dir, 'checkpoint_*'))
+ os.path.join(checkpoint_dir, 'checkpoint_*')
+ )
for path in checkpoint_files:
logging.info('Removing checkpoint at %s', path)
gfile.rmtree(path)
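
For context, the JAX branch of the save/restore code above is a thin wrapper around flax's checkpointing API. A minimal sketch of that round trip, using a toy pytree and a temporary directory rather than the real optimizer and model state, might look like:

import tempfile

import jax.numpy as jnp
from flax.training import checkpoints as flax_checkpoints

# Toy stand-in for the checkpoint_state dict assembled in save_checkpoint.
ckpt_dir = tempfile.mkdtemp()
state = {'model_params': {'w': jnp.ones((2, 2))}, 'global_step': 42}

flax_checkpoints.save_checkpoint(
  ckpt_dir, target=state, step=state['global_step'], overwrite=True, keep=1
)
restored = flax_checkpoints.restore_checkpoint(ckpt_dir, target=state)
assert int(restored['global_step']) == 42
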
diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py
index 37d1bd20f..919ccd125 100644
--- a/algoperf/data_utils.py
+++ b/algoperf/data_utils.py
@@ -15,9 +15,10 @@
def shard_and_maybe_pad_np(
- batch: Dict[str, spec.Tensor],
- padding_value: int = 0,
- global_batch_size: Optional[int] = None) -> Dict[str, spec.Tensor]:
+ batch: Dict[str, spec.Tensor],
+ padding_value: int = 0,
+ global_batch_size: Optional[int] = None,
+) -> Dict[str, spec.Tensor]:
"""Prepare tf data for JAX or PyTorch DDP.
Convert an input batch from tf Tensors to numpy arrays, pad it with
@@ -26,11 +27,13 @@ def shard_and_maybe_pad_np(
"""
local_device_count = max(torch.cuda.device_count(), jax.local_device_count())
inputs = batch['inputs']
- current_batch_size = inputs[0].shape[0] if isinstance(
- inputs, tuple) else inputs.shape[0]
+ current_batch_size = (
+ inputs[0].shape[0] if isinstance(inputs, tuple) else inputs.shape[0]
+ )
if global_batch_size is not None:
- assert global_batch_size >= current_batch_size, \
- 'global_batch_size must be larger than or equal to current_batch_size.'
+ assert global_batch_size >= current_batch_size, (
+ 'global_batch_size must be larger than or equal to current_batch_size.'
+ )
# Always pad to global_batch_size if it is provided.
pad_to_global_batch_size = global_batch_size > current_batch_size
else:
@@ -43,7 +46,8 @@ def shard_and_maybe_pad_np(
pad_size = local_device_count - remainder_size
targets = batch['targets']
targets_shape = tuple(
- targets[0].shape if isinstance(targets, tuple) else targets.shape)
+ targets[0].shape if isinstance(targets, tuple) else targets.shape
+ )
# We need a 2d mask for WMT.
mask_shape = targets_shape if len(targets_shape) < 3 else targets_shape[0]
# Get weights from batch if there are any.
@@ -68,9 +72,9 @@ def _prepare(x):
return jax.tree.map(_prepare, batch)
-def pad(tensor: np.ndarray,
- pad_size: int,
- padding_value: int = 0) -> np.ndarray:
+def pad(
+ tensor: np.ndarray, pad_size: int, padding_value: int = 0
+) -> np.ndarray:
if tensor.ndim > 1:
pad_size = (pad_size, *tensor.shape[1:])
padding = np.full(pad_size, padding_value, dtype=tensor.dtype)
@@ -78,8 +82,9 @@ def pad(tensor: np.ndarray,
return padded_tensor
-def mixup_pytorch(batch: Tuple[spec.Tensor, spec.Tensor],
- alpha: float = 0.2) -> Tuple[spec.Tensor, spec.Tensor]:
+def mixup_pytorch(
+ batch: Tuple[spec.Tensor, spec.Tensor], alpha: float = 0.2
+) -> Tuple[spec.Tensor, spec.Tensor]:
inputs, targets = batch
# Transform to one-hot targets.
targets = F.one_hot(targets, num_classes=1000)
@@ -144,12 +149,14 @@ class DistributedEvalSampler(Sampler):
... train(loader)
"""
- def __init__(self,
- dataset: torch.utils.data.Dataset,
- num_replicas: Optional[int] = None,
- rank: Optional[int] = None,
- shuffle: bool = False,
- seed: int = 0) -> None:
+ def __init__(
+ self,
+ dataset: torch.utils.data.Dataset,
+ num_replicas: Optional[int] = None,
+ rank: Optional[int] = None,
+ shuffle: bool = False,
+ seed: int = 0,
+ ) -> None:
if num_replicas is None:
if not dist.is_available():
raise RuntimeError('Requires distributed package to be available.')
@@ -165,7 +172,7 @@ def __init__(self,
# true value without extra samples
self.total_size = len(self.dataset)
indices = list(range(self.total_size))
- indices = indices[self.rank:self.total_size:self.num_replicas]
+ indices = indices[self.rank : self.total_size : self.num_replicas]
# true value without extra samples
self.num_samples = len(indices)
@@ -182,7 +189,7 @@ def __iter__(self) -> Iterable[int]:
indices = list(range(len(self.dataset)))
# Subsample.
- indices = indices[self.rank:self.total_size:self.num_replicas]
+ indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
@@ -203,11 +210,13 @@ def set_epoch(self, epoch: int) -> None:
# Modified from github.com/pytorch/pytorch/issues/23900#issuecomment-518858050.
-def cycle(iterable: Iterable,
- keys: Tuple[str, ...] = ('inputs', 'targets'),
- custom_sampler: bool = False,
- use_mixup: bool = False,
- mixup_alpha: float = 0.2) -> Iterable:
+def cycle(
+ iterable: Iterable,
+ keys: Tuple[str, ...] = ('inputs', 'targets'),
+ custom_sampler: bool = False,
+ use_mixup: bool = False,
+ mixup_alpha: float = 0.2,
+) -> Iterable:
iterator = iter(iterable)
epoch = 0
while True:
@@ -229,11 +238,9 @@ def cycle(iterable: Iterable,
# github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Classification/
# ConvNets/image_classification/dataloaders.py
class PrefetchedWrapper:
-
- def __init__(self,
- dataloader: DataLoader,
- device: torch.device,
- start_epoch: int = 0) -> None:
+ def __init__(
+ self, dataloader: DataLoader, device: torch.device, start_epoch: int = 0
+ ) -> None:
self.dataloader = dataloader
self.epoch = start_epoch
self.device = device
@@ -254,7 +261,8 @@ def prefetched_loader(self) -> Iterable[Tuple[spec.Tensor, spec.Tensor]]:
for next_inputs, next_targets in self.dataloader:
with torch.cuda.stream(stream):
next_inputs = next_inputs.to(
- self.device, dtype=torch.float, non_blocking=True)
+ self.device, dtype=torch.float, non_blocking=True
+ )
next_targets = next_targets.to(self.device, non_blocking=True)
if not first:
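
The core idea behind shard_and_maybe_pad_np above is simple: pad the leading batch dimension until it divides evenly across local devices, then reshape to [num_devices, per_device_batch, ...]. A numpy-only sketch of that step (hypothetical helper, not the repo function itself):

import numpy as np

def pad_and_shard(x: np.ndarray, num_devices: int, padding_value: int = 0):
  # Pad the batch dimension so it is divisible by num_devices, then shard.
  remainder = x.shape[0] % num_devices
  if remainder:
    pad_rows = np.full(
      (num_devices - remainder, *x.shape[1:]), padding_value, dtype=x.dtype
    )
    x = np.concatenate([x, pad_rows], axis=0)
  return x.reshape(num_devices, -1, *x.shape[1:])

batch = np.arange(10, dtype=np.float32).reshape(10, 1)
sharded = pad_and_shard(batch, num_devices=4)  # (4, 3, 1), two padded rows
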
diff --git a/algoperf/halton.py b/algoperf/halton.py
index 1f36b07bf..08c5466f1 100644
--- a/algoperf/halton.py
+++ b/algoperf/halton.py
@@ -36,10 +36,12 @@ def _is_prime(n: int) -> bool:
return all(n % i != 0 for i in range(2, int(n**0.5) + 1)) and n != 2
-def _generate_dim(num_samples: int,
- base: int,
- per_dim_shift: bool,
- shuffled_seed_sequence: List[int]) -> List[float]:
+def _generate_dim(
+ num_samples: int,
+ base: int,
+ per_dim_shift: bool,
+ shuffled_seed_sequence: List[int],
+) -> List[float]:
"""Generate `num_samples` from a Van der Corput sequence with base `base`.
Args:
@@ -59,8 +61,9 @@ def _generate_dim(num_samples: int,
ValueError: if `base` is negative or not prime.
"""
if base < 0 or not _is_prime(base):
- raise ValueError('Each Van der Corput sequence requires a prime `base`, '
- f'received {base}.')
+ raise ValueError(
+ f'Each Van der Corput sequence requires a prime `base`, received {base}.'
+ )
rng = random.RandomState(base)
if shuffled_seed_sequence is None:
@@ -76,7 +79,7 @@ def _generate_dim(num_samples: int,
dim_sequence = []
for i in range(1, num_samples + 1):
- num = 0.
+ num = 0.0
denominator = base
while i:
num += shuffled_seed_sequence[i % base] / denominator
@@ -91,13 +94,15 @@ def _generate_dim(num_samples: int,
Matrix = List[List[int]]
-def generate_sequence(num_samples: int,
- num_dims: int,
- skip: int = 100,
- per_dim_shift: bool = True,
- shuffle_sequence: bool = True,
- primes: Sequence[int] = None,
- shuffled_seed_sequence: Matrix = None) -> Matrix:
+def generate_sequence(
+ num_samples: int,
+ num_dims: int,
+ skip: int = 100,
+ per_dim_shift: bool = True,
+ shuffle_sequence: bool = True,
+ primes: Sequence[int] = None,
+ shuffled_seed_sequence: Matrix = None,
+) -> Matrix:
"""Generate `num_samples` from a Halton sequence of dimension `num_dims`.
Each dimension is generated independently from a shuffled Van der Corput
@@ -140,25 +145,29 @@ def generate_sequence(num_samples: int,
if primes is not None and len(primes) != num_dims:
raise ValueError(
- 'If passing in a sequence of primes it must be the same length as '
- f'num_dims={num_dims}, received {primes} (len {len(primes)}).')
+ 'If passing in a sequence of primes it must be the same length as '
+ f'num_dims={num_dims}, received {primes} (len {len(primes)}).'
+ )
if shuffled_seed_sequence is not None:
if len(shuffled_seed_sequence) != num_dims:
raise ValueError(
- 'If passing in `shuffled_seed_sequence` it must be the same length '
- f'as num_dims={num_dims}, received {shuffled_seed_sequence} '
- f'(len {len(shuffled_seed_sequence)}).')
+ 'If passing in `shuffled_seed_sequence` it must be the same length '
+ f'as num_dims={num_dims}, received {shuffled_seed_sequence} '
+ f'(len {len(shuffled_seed_sequence)}).'
+ )
for d in range(num_dims):
if len(shuffled_seed_sequence[d]) != primes[d]:
raise ValueError(
- 'If passing in `shuffled_seed_sequence` it must have element `{d}` '
- 'be a sequence of length `primes[{d}]`={expected}, received '
- '{actual} (len {length})'.format(
- d=d,
- expected=primes[d],
- actual=shuffled_seed_sequence[d],
- length=shuffled_seed_sequence[d]))
+ 'If passing in `shuffled_seed_sequence` it must have element `{d}` '
+ 'be a sequence of length `primes[{d}]`={expected}, received '
+ '{actual} (len {length})'.format(
+ d=d,
+ expected=primes[d],
+ actual=shuffled_seed_sequence[d],
+ length=shuffled_seed_sequence[d],
+ )
+ )
if primes is None:
primes = []
@@ -166,7 +175,7 @@ def generate_sequence(num_samples: int,
while len(primes) < num_dims + 1:
primes = generate_primes(1000 * prime_attempts)
prime_attempts += 1
- primes = primes[-num_dims - 1:-1]
+ primes = primes[-num_dims - 1 : -1]
# Skip the first `skip` points in the sequence because they can have unwanted
# correlations.
@@ -179,10 +188,11 @@ def generate_sequence(num_samples: int,
else:
dim_shuffled_seed_sequence = shuffled_seed_sequence[d]
dim_sequence = _generate_dim(
- num_samples=num_samples,
- base=primes[d],
- shuffled_seed_sequence=dim_shuffled_seed_sequence,
- per_dim_shift=per_dim_shift)
+ num_samples=num_samples,
+ base=primes[d],
+ shuffled_seed_sequence=dim_shuffled_seed_sequence,
+ per_dim_shift=per_dim_shift,
+ )
dim_sequence = dim_sequence[skip:]
halton_sequence.append(dim_sequence)
@@ -195,29 +205,29 @@ def generate_sequence(num_samples: int,
return halton_sequence
-def _generate_double_point(name: str,
- min_val: float,
- max_val: float,
- scaling: str,
- halton_point: float) -> Tuple[str, float]:
+def _generate_double_point(
+ name: str, min_val: float, max_val: float, scaling: str, halton_point: float
+) -> Tuple[str, float]:
"""Generate a float hyperparameter value from a Halton sequence point."""
if scaling not in ['linear', 'log']:
raise ValueError(
- 'Only log or linear scaling is supported for floating point '
- f'parameters. Received {scaling}.')
+ 'Only log or linear scaling is supported for floating point '
+ f'parameters. Received {scaling}.'
+ )
if scaling == 'log':
# To transform from [0, 1] to [min_val, max_val] on a log scale we do:
# min_val * exp(x * log(max_val / min_val)).
- rescaled_value = (
- min_val * math.exp(halton_point * math.log(max_val / min_val)))
+ rescaled_value = min_val * math.exp(
+ halton_point * math.log(max_val / min_val)
+ )
else:
rescaled_value = halton_point * (max_val - min_val) + min_val
return name, rescaled_value
-def _generate_discrete_point(name: str,
- feasible_points: Sequence[Any],
- halton_point: float) -> Any:
+def _generate_discrete_point(
+ name: str, feasible_points: Sequence[Any], halton_point: float
+) -> Any:
"""Generate a discrete hyperparameter value from a Halton sequence point."""
index = int(math.floor(halton_point * len(feasible_points)))
return name, feasible_points[index]
@@ -236,27 +246,23 @@ def interval(start: int, end: int) -> Tuple[int, int]:
def loguniform(name: str, range_endpoints: Tuple[int, int]) -> _GeneratorFn:
min_val, max_val = range_endpoints
- return functools.partial(_generate_double_point,
- name,
- min_val,
- max_val,
- 'log')
+ return functools.partial(
+ _generate_double_point, name, min_val, max_val, 'log'
+ )
def uniform(
- name: str, search_points: Union[_DiscretePoints,
- Tuple[int, int]]) -> _GeneratorFn:
+ name: str, search_points: Union[_DiscretePoints, Tuple[int, int]]
+) -> _GeneratorFn:
if isinstance(search_points, _DiscretePoints):
- return functools.partial(_generate_discrete_point,
- name,
- search_points.feasible_points)
+ return functools.partial(
+ _generate_discrete_point, name, search_points.feasible_points
+ )
min_val, max_val = search_points
- return functools.partial(_generate_double_point,
- name,
- min_val,
- max_val,
- 'linear')
+ return functools.partial(
+ _generate_double_point, name, min_val, max_val, 'linear'
+ )
def product(sweeps: Sequence[_SweepSequence]) -> _SweepSequence:
@@ -277,9 +283,10 @@ def sweep(name, feasible_points: Sequence[Any]) -> _SweepSequence:
return [{name: x} for x in feasible_points.feasible_points]
-def zipit(generator_fns_or_sweeps: Sequence[Union[_GeneratorFn,
- _SweepSequence]],
- length: int) -> _SweepSequence:
+def zipit(
+ generator_fns_or_sweeps: Sequence[Union[_GeneratorFn, _SweepSequence]],
+ length: int,
+) -> _SweepSequence:
"""Zip together a list of hyperparameter generators.
Args:
@@ -302,7 +309,8 @@ def zipit(generator_fns_or_sweeps: Sequence[Union[_GeneratorFn,
hyperparameter name from generator_fns_or_sweeps.
"""
halton_sequence = generate_sequence(
- num_samples=length, num_dims=len(generator_fns_or_sweeps))
+ num_samples=length, num_dims=len(generator_fns_or_sweeps)
+ )
# A List[Dict] of hyperparameter names to sweep values.
hyperparameter_sweep = []
for trial_index in range(length):
@@ -326,8 +334,9 @@ def zipit(generator_fns_or_sweeps: Sequence[Union[_GeneratorFn,
_ListSearchSpace = List[Dict[str, Union[str, float, Sequence]]]
-def generate_search(search_space: Union[_DictSearchSpace, _ListSearchSpace],
- num_trials: int) -> List[collections.namedtuple]:
+def generate_search(
+ search_space: Union[_DictSearchSpace, _ListSearchSpace], num_trials: int
+) -> List[collections.namedtuple]:
"""Generate a random search with the given bounds and scaling.
  Args:
@@ -352,8 +361,9 @@ def generate_search(search_space: Union[_DictSearchSpace, _ListSearchSpace],
else:
raise AttributeError('tuning_search_space should either be a dict or list.')
- named_tuple_class = collections.namedtuple('Hyperparameters',
- all_hyperparameter_names)
+ named_tuple_class = collections.namedtuple(
+ 'Hyperparameters', all_hyperparameter_names
+ )
if isinstance(search_space, dict):
hyperparameter_generators = []
@@ -367,16 +377,18 @@ def generate_search(search_space: Union[_DictSearchSpace, _ListSearchSpace],
generator_fn = uniform(name, interval(space['min'], space['max']))
hyperparameter_generators.append(generator_fn)
return [
- named_tuple_class(**p)
- for p in zipit(hyperparameter_generators, num_trials)
+ named_tuple_class(**p)
+ for p in zipit(hyperparameter_generators, num_trials)
]
else:
hyperparameters = []
updated_num_trials = min(num_trials, len(search_space))
if num_trials != len(search_space):
- logging.info(f'--num_tuning_trials was set to {num_trials}, but '
- f'{len(search_space)} trial(s) found in the JSON file. '
- f'Updating --num_tuning_trials to {updated_num_trials}.')
+ logging.info(
+ f'--num_tuning_trials was set to {num_trials}, but '
+ f'{len(search_space)} trial(s) found in the JSON file. '
+ f'Updating --num_tuning_trials to {updated_num_trials}.'
+ )
for trial in search_space:
hyperparameters.append(named_tuple_class(**trial))
return hyperparameters[:updated_num_trials]
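
The _generate_dim routine patched above is a shuffled, shifted Van der Corput sequence. The plain radical-inverse version is only a few lines and makes the digit-reversal idea concrete (illustrative sketch; the repo additionally shuffles digits per base and applies a per-dimension shift):

from typing import List

def van_der_corput(num_samples: int, base: int) -> List[float]:
  # Reflect the base-`base` digits of i around the radix point.
  sequence = []
  for i in range(1, num_samples + 1):
    value, denominator = 0.0, base
    while i:
      value += (i % base) / denominator
      i //= base
      denominator *= base
    sequence.append(value)
  return sequence

print(van_der_corput(4, base=2))  # [0.5, 0.25, 0.75, 0.125]
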
diff --git a/algoperf/init_utils.py b/algoperf/init_utils.py
index 185480cc7..c66a0be20 100644
--- a/algoperf/init_utils.py
+++ b/algoperf/init_utils.py
@@ -12,7 +12,7 @@
def pytorch_default_init(module: nn.Module) -> None:
# Perform lecun_normal initialization.
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
- std = math.sqrt(1. / fan_in) / .87962566103423978
+ std = math.sqrt(1.0 / fan_in) / 0.87962566103423978
nn.init.trunc_normal_(module.weight, std=std, a=-2 * std, b=2 * std)
if module.bias is not None:
- nn.init.constant_(module.bias, 0.)
+ nn.init.constant_(module.bias, 0.0)
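
The constant 0.87962566103423978 in pytorch_default_init above is the standard deviation of a unit normal truncated to [-2, 2]; dividing by it restores the intended LeCun-normal variance after truncation. A sketch applying the same recipe to a plain torch.nn.Linear layer:

import math

import torch.nn as nn

layer = nn.Linear(256, 128)
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(layer.weight)
# LeCun normal, corrected for the variance lost by truncating at +/- 2 std.
std = math.sqrt(1.0 / fan_in) / 0.87962566103423978
nn.init.trunc_normal_(layer.weight, std=std, a=-2 * std, b=2 * std)
nn.init.constant_(layer.bias, 0.0)
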
diff --git a/algoperf/interop_utils.py b/algoperf/interop_utils.py
index 0c6535d7a..c30d0cf3b 100644
--- a/algoperf/interop_utils.py
+++ b/algoperf/interop_utils.py
@@ -6,7 +6,8 @@
def jax_to_pytorch(x: spec.Tensor, take_ownership: bool = False) -> spec.Tensor:
return torch.utils.dlpack.from_dlpack(
- jax.dlpack.to_dlpack(x, take_ownership=take_ownership))
+ jax.dlpack.to_dlpack(x, take_ownership=take_ownership)
+ )
def pytorch_to_jax(x: torch.Tensor) -> spec.Tensor:
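
Both helpers in interop_utils.py move arrays between frameworks through DLPack without copying the underlying buffer where the backends allow it. A small CPU round-trip sketch of the same pattern; note that newer JAX/PyTorch releases also accept arrays directly instead of DLPack capsules:

import jax.dlpack
import jax.numpy as jnp
import torch
import torch.utils.dlpack

x = jnp.arange(4.0)
t = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))
x_back = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))
assert bool(jnp.all(x == x_back))
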
diff --git a/algoperf/logger_utils.py b/algoperf/logger_utils.py
index c988956dc..3c4898142 100644
--- a/algoperf/logger_utils.py
+++ b/algoperf/logger_utils.py
@@ -37,12 +37,12 @@ def makedir(dir_name: str, exist_ok: bool = True) -> None:
def get_log_dir(
- experiment_dir: str,
- workload: spec.Workload,
- framework: str,
- experiment_name: str,
- resume_last_run: bool,
- overwrite: bool,
+ experiment_dir: str,
+ workload: spec.Workload,
+ framework: str,
+ experiment_name: str,
+ resume_last_run: bool,
+ overwrite: bool,
) -> Optional[str]:
# Construct path to experiment workload directory.
experiment_dir = os.path.expanduser(experiment_dir)
@@ -50,26 +50,29 @@ def get_log_dir(
if experiment_name is None:
experiment_path = os.path.join(experiment_dir, workload_dir_name)
else:
- experiment_path = os.path.join(experiment_dir,
- experiment_name,
- workload_dir_name)
+ experiment_path = os.path.join(
+ experiment_dir, experiment_name, workload_dir_name
+ )
if os.path.exists(experiment_path):
if overwrite:
logging.info(
- f'Removing existing experiment directory {experiment_path} because '
- '--overwrite was set.')
+ f'Removing existing experiment directory {experiment_path} because '
+ '--overwrite was set.'
+ )
if RANK == 0:
shutil.rmtree(experiment_path)
elif resume_last_run:
logging.info(
- f'Resuming from experiment directory {experiment_path} because '
- '--resume_last_run was set.')
+ f'Resuming from experiment directory {experiment_path} because '
+ '--resume_last_run was set.'
+ )
else:
if RANK == 0:
resume = input(
- 'Found existing experiment dir with the same name: {}. Do you wish '
- 'to resume training from this dir? [y/N]:'.format(experiment_path))
+ 'Found existing experiment dir with the same name: {}. Do you wish '
+ 'to resume training from this dir? [y/N]:'.format(experiment_path)
+ )
if resume.lower() != 'y':
sys.exit()
@@ -83,16 +86,18 @@ def get_log_dir(
return experiment_path
-def write_hparams(hparams: spec.Hyperparameters,
- tuning_dir: str) -> spec.Hyperparameters:
+def write_hparams(
+ hparams: spec.Hyperparameters, tuning_dir: str
+) -> spec.Hyperparameters:
hparams_file_name = os.path.join(tuning_dir, 'hparams.json')
if os.path.exists(hparams_file_name):
# If hparams.json already exist, use the previously saved hyperparameters.
logging.info('Loading hparams from %s.', hparams_file_name)
with open(hparams_file_name, 'r') as f:
hparams_dict = json.load(f)
- hparams = collections.namedtuple('Hyperparameters',
- hparams_dict)(**hparams_dict)
+ hparams = collections.namedtuple('Hyperparameters', hparams_dict)(
+ **hparams_dict
+ )
else:
logging.info('Saving hparams to %s.', hparams_file_name)
if RANK == 0:
@@ -108,8 +113,8 @@ def write_json(name: str, log_dict: Dict, indent: int = 2) -> None:
def write_to_csv(
- metrics: Dict,
- csv_path: str,
+ metrics: Dict,
+ csv_path: str,
) -> None:
try:
with open(csv_path, 'r') as csv_file:
@@ -118,8 +123,10 @@ def write_to_csv(
except (pd.errors.EmptyDataError, FileNotFoundError) as e:
measurements = pd.DataFrame([metrics], columns=sorted(metrics.keys()))
if isinstance(e, pd.errors.EmptyDataError):
- logging.info('Measurements file is empty. Create a new one, starting '
- 'with metrics from this step.')
+ logging.info(
+ 'Measurements file is empty. Create a new one, starting '
+ 'with metrics from this step.'
+ )
with open(csv_path, 'w') as csv_file:
measurements.to_csv(csv_file, index=False)
return
@@ -130,7 +137,8 @@ def _get_utilization() -> Dict:
# CPU
util_data['cpu.util.avg_percent_since_last'] = psutil.cpu_percent(
- interval=None) # non-blocking (cpu util percentage since last call)
+ interval=None
+ ) # non-blocking (cpu util percentage since last call)
util_data['cpu.freq.current'] = psutil.cpu_freq().current
# Memory
@@ -208,11 +216,14 @@ def _get_system_hardware_info() -> Dict:
def _get_system_software_info() -> Dict:
system_software_info = {}
- system_software_info['os_platform'] = \
- platform.platform() # Ex. 'Linux-5.4.48-x86_64-with-glibc2.29'
- system_software_info['python_version'] = platform.python_version(
+ system_software_info['os_platform'] = (
+ platform.platform()
+ ) # Ex. 'Linux-5.4.48-x86_64-with-glibc2.29'
+ system_software_info['python_version'] = (
+ platform.python_version()
) # Ex. '3.11.10'
- system_software_info['python_compiler'] = platform.python_compiler(
+ system_software_info['python_compiler'] = (
+ platform.python_compiler()
) # Ex. 'GCC 9.3.0'
# Note: do not store hostname as that may be sensitive
@@ -228,19 +239,28 @@ def _get_system_software_info() -> Dict:
def _get_git_commit_hash() -> str:
- return subprocess.check_output(['git', 'rev-parse',
- 'HEAD']).decode('ascii').strip()
+ return (
+ subprocess.check_output(['git', 'rev-parse', 'HEAD'])
+ .decode('ascii')
+ .strip()
+ )
def _get_git_branch() -> str:
- return subprocess.check_output(['git', 'rev-parse', '--abbrev-ref',
- 'HEAD']).decode('ascii').strip()
+ return (
+ subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
+ .decode('ascii')
+ .strip()
+ )
def _get_cpu_model_name() -> str:
output = subprocess.check_output(['lscpu']).decode('ascii').strip()
- return re.findall(r'(?=Model name:\s{1,}).*',
- output)[0].split('Model name:')[1].strip()
+ return (
+ re.findall(r'(?=Model name:\s{1,}).*', output)[0]
+ .split('Model name:')[1]
+ .strip()
+ )
def _is_primitive_type(item: Any) -> bool:
@@ -252,23 +272,25 @@ def _get_workload_properties(workload: spec.Workload) -> Dict:
workload_properties = {}
skip_list = ['param_shapes', 'model_params_types']
keys = [
- key for key in dir(workload)
- if not key.startswith('_') and key not in skip_list
+ key
+ for key in dir(workload)
+ if not key.startswith('_') and key not in skip_list
]
for key in keys:
try:
attr = getattr(workload, key)
except: # pylint: disable=bare-except
logging.info(
- f'Unable to record workload.{key} information. Continuing without it.'
+ f'Unable to record workload.{key} information. Continuing without it.'
)
if _is_primitive_type(attr):
workload_properties[f'workload.{key}'] = attr
return workload_properties
-def get_meta_data(workload: spec.Workload,
- rng_seed: Optional[int] = None) -> Dict:
+def get_meta_data(
+ workload: spec.Workload, rng_seed: Optional[int] = None
+) -> Dict:
meta_data = {}
workload_properties = _get_workload_properties(workload)
meta_data.update(workload_properties)
@@ -290,12 +312,14 @@ class MetricLogger(object):
the wrong time.
"""
- def __init__(self,
- csv_path: str,
- eval_csv_path: str,
- events_dir: Optional[str] = None,
- configs: Optional[flags.FLAGS] = None,
- hyperparameters: Optional[spec.Hyperparameters] = None) -> None:
+ def __init__(
+ self,
+ csv_path: str,
+ eval_csv_path: str,
+ events_dir: Optional[str] = None,
+ configs: Optional[flags.FLAGS] = None,
+ hyperparameters: Optional[spec.Hyperparameters] = None,
+ ) -> None:
self._measurements = {}
self._csv_path = csv_path
self._eval_csv_path = eval_csv_path
@@ -305,15 +329,18 @@ def __init__(self,
self._tb_metric_writer = metric_writers.create_default_writer(events_dir)
if wandb is not None and self.use_wandb:
wandb.init(
- dir=events_dir, tags=[flags.FLAGS.workload, flags.FLAGS.framework])
+ dir=events_dir, tags=[flags.FLAGS.workload, flags.FLAGS.framework]
+ )
wandb.config.update(configs)
wandb.config.update(hyperparameters._asdict())
- def append_scalar_metrics(self,
- metrics: Dict,
- global_step: int,
- preemption_count: Optional[int] = None,
- is_eval: bool = False) -> None:
+ def append_scalar_metrics(
+ self,
+ metrics: Dict,
+ global_step: int,
+ preemption_count: Optional[int] = None,
+ is_eval: bool = False,
+ ) -> None:
metrics['global_step'] = global_step
if preemption_count is not None:
metrics['preemption_count'] = preemption_count
@@ -324,7 +351,8 @@ def append_scalar_metrics(self,
if self._tb_metric_writer:
self._tb_metric_writer.write_scalars(
- step=int(metrics['global_step']), scalars=metrics)
+ step=int(metrics['global_step']), scalars=metrics
+ )
self._tb_metric_writer.flush()
if wandb is not None and self.use_wandb:
@@ -335,15 +363,16 @@ def finish(self) -> None:
wandb.finish()
-def set_up_loggers(train_dir: str,
- configs: flags.FLAGS,
- hyperparameters: spec.Hyperparameters) -> MetricLogger:
+def set_up_loggers(
+ train_dir: str, configs: flags.FLAGS, hyperparameters: spec.Hyperparameters
+) -> MetricLogger:
csv_path = os.path.join(train_dir, 'measurements.csv')
eval_csv_path = os.path.join(train_dir, 'eval_measurements.csv')
metrics_logger = MetricLogger(
- csv_path=csv_path,
- eval_csv_path=eval_csv_path,
- events_dir=train_dir,
- configs=configs,
- hyperparameters=hyperparameters)
+ csv_path=csv_path,
+ eval_csv_path=eval_csv_path,
+ events_dir=train_dir,
+ configs=configs,
+ hyperparameters=hyperparameters,
+ )
return metrics_logger
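
The write_to_csv path above boils down to "read the existing measurements file if there is one, append a row, write it back". A compact pandas sketch of that pattern (hypothetical file name, simplified error handling):

import os

import pandas as pd

def append_metrics(metrics: dict, csv_path: str) -> None:
  row = pd.DataFrame([metrics], columns=sorted(metrics.keys()))
  if os.path.exists(csv_path):
    row = pd.concat([pd.read_csv(csv_path), row], ignore_index=True)
  row.to_csv(csv_path, index=False)

append_metrics({'global_step': 100, 'loss': 0.42}, '/tmp/measurements.csv')
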
diff --git a/algoperf/param_utils.py b/algoperf/param_utils.py
index 05d882404..908ef0f27 100644
--- a/algoperf/param_utils.py
+++ b/algoperf/param_utils.py
@@ -14,7 +14,8 @@ def pytorch_param_shapes(model: nn.Module) -> Dict[str, spec.ShapeTuple]:
def pytorch_param_types(
- param_shapes: Dict[str, spec.ShapeTuple]) -> Dict[str, spec.ParameterType]:
+ param_shapes: Dict[str, spec.ShapeTuple],
+) -> Dict[str, spec.ParameterType]:
param_types = {}
for name in param_shapes.keys():
if 'bn' in name:
@@ -65,18 +66,21 @@ def pytorch_param_types(
def jax_param_shapes(
- params: spec.ParameterContainer) -> spec.ParameterShapeTree:
+ params: spec.ParameterContainer,
+) -> spec.ParameterShapeTree:
return jax.tree.map(lambda x: spec.ShapeTuple(x.shape), params)
-def jax_param_types(param_shapes: spec.ParameterShapeTree,
- parent_name: str = '') -> Dict[str, spec.ParameterType]:
+def jax_param_types(
+ param_shapes: spec.ParameterShapeTree, parent_name: str = ''
+) -> Dict[str, spec.ParameterType]:
param_types = {}
for name, value in param_shapes.items():
name = name.lower()
if isinstance(value, dict) or isinstance(value, flax.core.FrozenDict):
param_types[name] = jax_param_types(
- value, parent_name=parent_name + '/' + name)
+ value, parent_name=parent_name + '/' + name
+ )
else:
if 'batchnorm' in parent_name or 'bn' in parent_name:
if name == 'scale':
@@ -85,7 +89,8 @@ def jax_param_types(param_shapes: spec.ParameterShapeTree,
param_types[name] = spec.ParameterType.BATCH_NORM_BIAS
else:
raise ValueError(
- f'Unrecognized batch norm parameter: {parent_name}/{name}.')
+ f'Unrecognized batch norm parameter: {parent_name}/{name}.'
+ )
elif 'layernorm' in parent_name or 'ln' in parent_name:
if name == 'scale':
param_types[name] = spec.ParameterType.LAYER_NORM_SCALE
@@ -93,7 +98,8 @@ def jax_param_types(param_shapes: spec.ParameterShapeTree,
param_types[name] = spec.ParameterType.LAYER_NORM_BIAS
else:
raise ValueError(
- f'Unrecognized layer norm parameter: {parent_name}/{name}.')
+ f'Unrecognized layer norm parameter: {parent_name}/{name}.'
+ )
elif 'conv' in parent_name:
if 'bias' in name:
param_types[name] = spec.ParameterType.BIAS
@@ -102,8 +108,9 @@ def jax_param_types(param_shapes: spec.ParameterShapeTree,
# Note that this is exact equality, not contained in, because
# flax.linen.Embed names the embedding parameter "embedding"
# https://github.com/google/flax/blob/main/flax/linen/linear.py#L604.
- elif ('embedding' in name or
- ('embedding' in parent_name and name == 'kernel')):
+ elif 'embedding' in name or (
+ 'embedding' in parent_name and name == 'kernel'
+ ):
param_types[name] = spec.ParameterType.EMBEDDING
elif 'attention' in parent_name:
if name == 'bias':
@@ -122,7 +129,8 @@ def jax_param_types(param_shapes: spec.ParameterShapeTree,
param_types[name] = spec.ParameterType.ATTENTION_QKV
else:
raise ValueError(
- f'Unrecognized attention parameter: {parent_name}/{name}.')
+ f'Unrecognized attention parameter: {parent_name}/{name}.'
+ )
elif 'bias' in name:
param_types[name] = spec.ParameterType.BIAS
else:
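
param_utils maps parameter names to spec.ParameterType values with substring rules on the lower-cased module path. A toy version of that classification, with made-up rules that only cover a few of the cases handled above:

import enum

class ParameterType(enum.Enum):
  WEIGHT = 0
  BIAS = 1
  BATCH_NORM_SCALE = 2
  BATCH_NORM_BIAS = 3

def classify(name: str) -> ParameterType:
  name = name.lower()
  if 'bn' in name or 'batchnorm' in name:
    if 'weight' in name or 'scale' in name:
      return ParameterType.BATCH_NORM_SCALE
    return ParameterType.BATCH_NORM_BIAS
  return ParameterType.BIAS if 'bias' in name else ParameterType.WEIGHT

print(classify('encoder.bn1.weight'))  # ParameterType.BATCH_NORM_SCALE
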
diff --git a/algoperf/profiler.py b/algoperf/profiler.py
index fa2a1bee2..0e791d3a8 100644
--- a/algoperf/profiler.py
+++ b/algoperf/profiler.py
@@ -21,7 +21,6 @@ def _get_monotonic_time() -> float:
class Profiler:
-
def __init__(self, local_rank: Optional[int] = None) -> None:
self._local_rank = local_rank
@@ -41,7 +40,8 @@ def start(self, action_name: str) -> None:
pass
if action_name in self.current_actions:
raise ValueError(
- f'Attempted to start {action_name} which has already started.')
+ f'Attempted to start {action_name} which has already started.'
+ )
self.current_actions[action_name] = _get_monotonic_time()
def stop(self, action_name: str) -> None:
@@ -49,8 +49,10 @@ def stop(self, action_name: str) -> None:
pass
end_time = _get_monotonic_time()
if action_name not in self.current_actions:
- raise ValueError(f'Attempting to stop recording an action '
- f'({action_name}) which was never started.')
+ raise ValueError(
+ f'Attempting to stop recording an action '
+ f'({action_name}) which was never started.'
+ )
start_time = self.current_actions.pop(action_name)
duration = end_time - start_time
self.recorded_durations[action_name].append(duration)
@@ -64,16 +66,20 @@ def profile(self, action_name: str) -> Generator:
self.stop(action_name)
def _make_report(
- self
+ self,
) -> Tuple[List[Tuple[str, float, float, int, float, float]], int, float]:
total_duration = _get_monotonic_time() - self.start_time
- report = [(str(a),
- float(np.mean(d)),
- float(np.std(d)),
- len(d),
- float(np.sum(d)),
- 100.0 * float(np.sum(d)) / total_duration) for a,
- d in self.recorded_durations.items()]
+ report = [
+ (
+ str(a),
+ float(np.mean(d)),
+ float(np.std(d)),
+ len(d),
+ float(np.sum(d)),
+ 100.0 * float(np.sum(d)) / total_duration,
+ )
+ for a, d in self.recorded_durations.items()
+ ]
report.sort(key=lambda x: x[5], reverse=True)
total_calls = sum(x[3] for x in report)
return report, total_calls, total_duration
@@ -92,32 +98,42 @@ def log_row(action, mean, std, num_calls, total, per):
row += f' {num_calls:<15}\t| {total:<15}\t| {per:<15}\t|'
return row
- header_string = log_row('Action',
- 'Mean Duration (s)',
- 'Std Duration (s)',
- 'Num Calls',
- 'Total Time (s)',
- 'Percentage %')
+ header_string = log_row(
+ 'Action',
+ 'Mean Duration (s)',
+ 'Std Duration (s)',
+ 'Num Calls',
+ 'Total Time (s)',
+ 'Percentage %',
+ )
output_string_len = len(header_string.expandtabs())
sep_lines = f'{sep}{"-" * output_string_len}'
output_string += sep_lines + header_string + sep_lines
report, total_calls, total_duration = self._make_report()
- output_string += log_row('Total',
- '-----',
- '-----',
- f'{total_calls:}',
- f'{total_duration:.5}',
- '100 %')
+ output_string += log_row(
+ 'Total',
+ '-----',
+ '-----',
+ f'{total_calls:}',
+ f'{total_duration:.5}',
+ '100 %',
+ )
output_string += sep_lines
- for action, mean_duration, std_duration, num_calls, \
- total_duration, duration_per in report:
+ for (
+ action,
+ mean_duration,
+ std_duration,
+ num_calls,
+ total_duration,
+ duration_per,
+ ) in report:
output_string += log_row(
- action,
- f'{mean_duration:.5}',
- f'{std_duration:.5}',
- f'{num_calls}',
- f'{total_duration:.5}',
- f'{duration_per:.5}',
+ action,
+ f'{mean_duration:.5}',
+ f'{std_duration:.5}',
+ f'{num_calls}',
+ f'{total_duration:.5}',
+ f'{duration_per:.5}',
)
output_string += sep_lines
output_string += sep
@@ -125,7 +141,6 @@ def log_row(action, mean, std, num_calls, total, per):
class PassThroughProfiler(Profiler):
-
def start(self, action_name: str) -> None:
pass
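
The Profiler above records wall-clock durations per action name and reports mean, total, and percentage of runtime. A minimal context-manager sketch of the same bookkeeping (no ranks, no formatted table):

import time
from collections import defaultdict
from contextlib import contextmanager

durations = defaultdict(list)

@contextmanager
def profile(action_name: str):
  start = time.monotonic()
  try:
    yield
  finally:
    durations[action_name].append(time.monotonic() - start)

with profile('sleep'):
  time.sleep(0.01)

for action, recorded in durations.items():
  print(action, f'mean={sum(recorded) / len(recorded):.4f}s',
        f'total={sum(recorded):.4f}s')
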
diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py
index b81d2969a..429e4d1e2 100644
--- a/algoperf/pytorch_utils.py
+++ b/algoperf/pytorch_utils.py
@@ -12,10 +12,12 @@
from algoperf import spec
from algoperf.profiler import Profiler
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
- BatchNorm as ConformerBatchNorm
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \
- BatchNorm as DeepspeechBatchNorm
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import (
+ BatchNorm as ConformerBatchNorm,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import (
+ BatchNorm as DeepspeechBatchNorm,
+)
def pytorch_setup() -> Tuple[bool, int, torch.device, int]:
@@ -61,12 +63,13 @@ def sync_ddp_time(time: float, device: torch.device) -> float:
return time_tensor.item()
-def update_batch_norm_fn(module: spec.ParameterContainer,
- update_batch_norm: bool) -> None:
+def update_batch_norm_fn(
+ module: spec.ParameterContainer, update_batch_norm: bool
+) -> None:
bn_layers = (
- torch.nn.modules.batchnorm._BatchNorm, # PyTorch BN base class.
- ConformerBatchNorm, # Custom BN class for conformer model.
- DeepspeechBatchNorm, # Custom BN class for deepspeech model.
+ torch.nn.modules.batchnorm._BatchNorm, # PyTorch BN base class.
+ ConformerBatchNorm, # Custom BN class for conformer model.
+ DeepspeechBatchNorm, # Custom BN class for deepspeech model.
)
if isinstance(module, bn_layers):
if not update_batch_norm:
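
update_batch_norm_fn above walks the module tree and, when asked not to update batch norm, freezes the running statistics of each BN layer. A simplified sketch of that traversal using Module.apply with only the stock PyTorch BN base class; the repo's helper also covers its custom Conformer/Deepspeech BN classes and uses its own freezing logic rather than eval():

import torch

def freeze_bn_stats(module: torch.nn.Module) -> None:
  if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
    module.eval()  # stop updating running_mean / running_var

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
model.train()
model.apply(freeze_bn_stats)
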
diff --git a/algoperf/random_utils.py b/algoperf/random_utils.py
index a579976ad..41f4b6b41 100644
--- a/algoperf/random_utils.py
+++ b/algoperf/random_utils.py
@@ -10,8 +10,9 @@
import jax.random as jax_rng
except (ImportError, ModuleNotFoundError):
logging.warning(
- 'Could not import jax.random for the submission runner, falling back to '
- 'numpy random_utils.')
+ 'Could not import jax.random for the submission runner, falling back to '
+ 'numpy random_utils.'
+ )
jax_rng = None
FLAGS = flags.FLAGS
@@ -54,8 +55,9 @@ def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
def _check_jax_install() -> None:
if jax_rng is None:
raise ValueError(
- 'Must install jax to use the jax RNG library, or use PyTorch and pass '
- '--framework=pytorch to use the Numpy version instead.')
+ 'Must install jax to use the jax RNG library, or use PyTorch and pass '
+ '--framework=pytorch to use the Numpy version instead.'
+ )
def fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
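
random_utils mirrors the jax.random fold_in/split interface on top of numpy when JAX is not installed. An illustrative numpy-only version of that pair (hypothetical helpers, seeds kept inside the int32 range for portability):

import numpy as np

MAX_INT32 = 2**31 - 1

def fold_in(seed: int, data: int) -> int:
  rng = np.random.RandomState(seed=(seed + data) % MAX_INT32)
  return int(rng.randint(0, MAX_INT32))

def split(seed: int, num: int = 2) -> np.ndarray:
  rng = np.random.RandomState(seed=seed % MAX_INT32)
  return rng.randint(0, MAX_INT32, size=num)

step_seed = fold_in(1996, data=3)
data_rng, model_rng = split(step_seed)
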
diff --git a/algoperf/spec.py b/algoperf/spec.py
index 9670dcb76..5f7b930af 100644
--- a/algoperf/spec.py
+++ b/algoperf/spec.py
@@ -5,10 +5,10 @@
import functools
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
-from absl import logging
import jax
-from torch import nn
import torch.nn.functional as F
+from absl import logging
+from torch import nn
class LossType(enum.Enum):
@@ -53,7 +53,6 @@ class ParameterType(enum.Enum):
# Define this so that if using pytree iteration utilities, can iterate over the
# model shapes pytree without iterating over the shape tuples.
class ShapeTuple:
-
def __init__(self, shape_tuple):
self.shape_tuple = shape_tuple
@@ -64,19 +63,22 @@ def __eq__(self, other):
return self.shape_tuple == other.shape_tuple
-Shape = Union[Tuple[int],
- Tuple[int, int],
- Tuple[int, int, int],
- Tuple[int, int, int, int],
- ShapeTuple]
+Shape = Union[
+ Tuple[int],
+ Tuple[int, int],
+ Tuple[int, int, int],
+ Tuple[int, int, int, int],
+ ShapeTuple,
+]
ParameterShapeTree = Dict[str, Dict[str, Shape]]
# If necessary, these can be zipped together easily given they have the same
# structure, to get an iterator over pairs of leaves.
ParameterKey = str
# Dicts can be arbitrarily nested.
-ParameterContainer = Union[Dict[ParameterKey, Dict[ParameterKey, Tensor]],
- nn.Module]
+ParameterContainer = Union[
+ Dict[ParameterKey, Dict[ParameterKey, Tensor]], nn.Module
+]
ParameterTypeTree = Dict[ParameterKey, Dict[ParameterKey, ParameterType]]
RandomState = Any # Union[jax.random.PRNGKey, int, bytes, ...]
@@ -92,7 +94,6 @@ def __eq__(self, other):
class Workload(metaclass=abc.ABCMeta):
-
def __init__(self, *args, **kwargs) -> None:
del args
del kwargs
@@ -107,8 +108,9 @@ def target_metric_name(self) -> str:
"""The name of the target metric (useful for scoring/processing code)."""
@abc.abstractmethod
- def has_reached_validation_target(self, eval_result: Dict[str,
- float]) -> bool:
+ def has_reached_validation_target(
+ self, eval_result: Dict[str, float]
+ ) -> bool:
"""Return whether or not the workload validation goal has been reached."""
@abc.abstractmethod
@@ -117,14 +119,15 @@ def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool:
@abc.abstractmethod
def _build_input_queue(
- self,
- data_rng: RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, Any]]:
+ self,
+ data_rng: RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, Any]]:
"""Build the input queue for the workload data.
This is the only function that is NOT allowed to be called by submitters.
@@ -213,8 +216,9 @@ def param_shapes(self):
"""The shapes of the parameters in the workload model."""
if self._param_shapes is None:
raise ValueError(
- 'This should not happen, workload.init_model_fn() should be called '
- 'before workload.param_shapes!')
+ 'This should not happen, workload.init_model_fn() should be called '
+ 'before workload.param_shapes!'
+ )
return self._param_shapes
@property
@@ -222,8 +226,9 @@ def model_params_types(self):
"""The types of the parameters in the workload model."""
if self._param_types is None:
raise ValueError(
- 'This should not happen, workload.init_model_fn() should be called '
- 'before workload.param_types!')
+ 'This should not happen, workload.init_model_fn() should be called '
+ 'before workload.param_types!'
+ )
return self._param_types
@abc.abstractmethod
@@ -234,10 +239,12 @@ def is_output_params(self, param_key: ParameterKey) -> bool:
# Tuple[RandomState, Optional[float], Optional[float]],
# ParameterContainer]
@abc.abstractmethod
- def init_model_fn(self,
- rng: RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> ModelInitState:
+ def init_model_fn(
+ self,
+ rng: RandomState,
+ dropout_rate: Optional[float] = None,
+ aux_dropout_rate: Optional[float] = None,
+ ) -> ModelInitState:
"""Return (initial_params, initial_model_state)."""
# ModelFn = Callable[
@@ -251,30 +258,35 @@ def init_model_fn(self,
# float],
# Tensor]
@abc.abstractmethod
- def model_fn(self,
- params: ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, Tensor],
- model_state: ModelAuxiliaryState,
- mode: ForwardPassMode,
- rng: RandomState,
- update_batch_norm: bool,
- dropout_rate: float) -> Tuple[Tensor, ModelAuxiliaryState]:
+ def model_fn(
+ self,
+ params: ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, Tensor],
+ model_state: ModelAuxiliaryState,
+ mode: ForwardPassMode,
+ rng: RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float,
+ ) -> Tuple[Tensor, ModelAuxiliaryState]:
"""Return logits_batch"""
# Possible side effect of updating BN.
- def output_activation_fn(self, logits_batch: Tensor,
- framework: str) -> Tensor:
+ def output_activation_fn(
+ self, logits_batch: Tensor, framework: str
+ ) -> Tensor:
"""Turn logits into probabilities, according to the loss_type property."""
if framework not in ['pytorch', 'jax']:
raise ValueError(
- f'`framework` has to be either `pytorch` or `jax`, got {framework}.')
+ f'`framework` has to be either `pytorch` or `jax`, got {framework}.'
+ )
activation_fn = {
- LossType.MEAN_SQUARED_ERROR: lambda z: z,
- LossType.MEAN_ABSOLUTE_ERROR: lambda z: z,
+ LossType.MEAN_SQUARED_ERROR: lambda z: z,
+ LossType.MEAN_ABSOLUTE_ERROR: lambda z: z,
}
is_pytorch = framework == 'pytorch' # If False, framework == 'jax'.
softmax_fn = (
- functools.partial(F.softmax, dim=-1) if is_pytorch else jax.nn.softmax)
+ functools.partial(F.softmax, dim=-1) if is_pytorch else jax.nn.softmax
+ )
sigmoid_fn = F.sigmoid if is_pytorch else jax.nn.sigmoid
activation_fn[LossType.SOFTMAX_CROSS_ENTROPY] = softmax_fn
activation_fn[LossType.SIGMOID_CROSS_ENTROPY] = sigmoid_fn
@@ -286,12 +298,13 @@ def output_activation_fn(self, logits_batch: Tensor,
# `update_params`.
@abc.abstractmethod
def loss_fn(
- self,
- # Dense or one-hot labels, or a tuple of (tensor, padding) for speech.
- label_batch: Union[Tuple[Tensor, Tensor], Tensor],
- logits_batch: Union[Tuple[Tensor, Tensor], Tensor],
- mask_batch: Optional[Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, Tensor]: # differentiable
+ self,
+ # Dense or one-hot labels, or a tuple of (tensor, padding) for speech.
+ label_batch: Union[Tuple[Tensor, Tensor], Tensor],
+ logits_batch: Union[Tuple[Tensor, Tensor], Tensor],
+ mask_batch: Optional[Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -300,48 +313,54 @@ def loss_fn(
"""
@abc.abstractmethod
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: ParameterContainer,
- model_state: ModelAuxiliaryState,
- rng: RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: ParameterContainer,
+ model_state: ModelAuxiliaryState,
+ rng: RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Evaluate the model on a given dataset split, return final scalars."""
- def eval_model(self,
- global_batch_size: int,
- params: ParameterContainer,
- model_state: ModelAuxiliaryState,
- rng: RandomState,
- data_dir: str,
- imagenet_v2_data_dir: Optional[str],
- global_step: int) -> Dict[str, float]:
+ def eval_model(
+ self,
+ global_batch_size: int,
+ params: ParameterContainer,
+ model_state: ModelAuxiliaryState,
+ rng: RandomState,
+ data_dir: str,
+ imagenet_v2_data_dir: Optional[str],
+ global_step: int,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
logging.info('Evaluating on the training split.')
train_metrics = self._eval_model_on_split(
- split='eval_train',
- num_examples=self.num_eval_train_examples,
- global_batch_size=global_batch_size,
- params=params,
- model_state=model_state,
- rng=rng,
- data_dir=data_dir,
- global_step=global_step)
+ split='eval_train',
+ num_examples=self.num_eval_train_examples,
+ global_batch_size=global_batch_size,
+ params=params,
+ model_state=model_state,
+ rng=rng,
+ data_dir=data_dir,
+ global_step=global_step,
+ )
eval_metrics = {'train/' + k: v for k, v in train_metrics.items()}
# We always require a validation set.
logging.info('Evaluating on the validation split.')
validation_metrics = self._eval_model_on_split(
- 'validation',
- num_examples=self.num_validation_examples,
- global_batch_size=global_batch_size,
- params=params,
- model_state=model_state,
- rng=rng,
- data_dir=data_dir,
- global_step=global_step)
+ 'validation',
+ num_examples=self.num_validation_examples,
+ global_batch_size=global_batch_size,
+ params=params,
+ model_state=model_state,
+ rng=rng,
+ data_dir=data_dir,
+ global_step=global_step,
+ )
for k, v in validation_metrics.items():
eval_metrics['validation/' + k] = v
eval_metrics['validation/num_examples'] = self.num_validation_examples
@@ -350,14 +369,15 @@ def eval_model(self,
if self.num_test_examples is not None:
logging.info('Evaluating on the test split.')
test_metrics = self._eval_model_on_split(
- 'test',
- num_examples=self.num_test_examples,
- global_batch_size=global_batch_size,
- params=params,
- model_state=model_state,
- rng=rng,
- data_dir=imagenet_v2_data_dir if imagenet_v2_data_dir else data_dir,
- global_step=global_step)
+ 'test',
+ num_examples=self.num_test_examples,
+ global_batch_size=global_batch_size,
+ params=params,
+ model_state=model_state,
+ rng=rng,
+ data_dir=imagenet_v2_data_dir if imagenet_v2_data_dir else data_dir,
+ global_step=global_step,
+ )
for k, v in test_metrics.items():
eval_metrics['test/' + k] = v
eval_metrics['test/num_examples'] = self.num_test_examples
@@ -374,27 +394,32 @@ class TrainingCompleteError(Exception):
# Training algorithm track submission functions, to be filled in by the
# submitter.
-InitOptimizerFn = Callable[[
+InitOptimizerFn = Callable[
+ [
Workload,
ParameterContainer,
ModelAuxiliaryState,
Hyperparameters,
- RandomState
-],
- OptimizerState]
-
-
-def init_optimizer_state(workload: Workload,
- model_params: ParameterContainer,
- model_state: ModelAuxiliaryState,
- hyperparameters: Hyperparameters,
- rng: RandomState) -> OptimizerState:
+ RandomState,
+ ],
+ OptimizerState,
+]
+
+
+def init_optimizer_state(
+ workload: Workload,
+ model_params: ParameterContainer,
+ model_state: ModelAuxiliaryState,
+ hyperparameters: Hyperparameters,
+ rng: RandomState,
+) -> OptimizerState:
# return initial_optimizer_state
pass
UpdateReturn = Tuple[OptimizerState, ParameterContainer, ModelAuxiliaryState]
-UpdateParamsFn = Callable[[
+UpdateParamsFn = Callable[
+ [
Workload,
ParameterContainer,
ParameterTypeTree,
@@ -406,9 +431,10 @@ def init_optimizer_state(workload: Workload,
List[Tuple[int, float]],
int,
RandomState,
- Optional[Dict[str, Any]]
-],
- UpdateReturn]
+ Optional[Dict[str, Any]],
+ ],
+ UpdateReturn,
+]
# Each call to this function is considered a "step".
@@ -417,23 +443,26 @@ def init_optimizer_state(workload: Workload,
# and if it has not actually achieved the goal, it will be counted as not having
# achieved the goal and receive an infinite time score. Most submissions will likely
# wait until the next free eval and not use this functionality.
-def update_params(workload: Workload,
- current_param_container: ParameterContainer,
- current_params_types: ParameterTypeTree,
- model_state: ModelAuxiliaryState,
- hyperparameters: Hyperparameters,
- batch: Dict[str, Tensor],
- loss_type: LossType,
- optimizer_state: OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: RandomState,
- train_state: Optional[Dict[str, Any]] = None) -> UpdateReturn:
+def update_params(
+ workload: Workload,
+ current_param_container: ParameterContainer,
+ current_params_types: ParameterTypeTree,
+ model_state: ModelAuxiliaryState,
+ hyperparameters: Hyperparameters,
+ batch: Dict[str, Tensor],
+ loss_type: LossType,
+ optimizer_state: OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: RandomState,
+ train_state: Optional[Dict[str, Any]] = None,
+) -> UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
pass
-PrepareForEvalFn = Callable[[
+PrepareForEvalFn = Callable[
+ [
Workload,
ParameterContainer,
ParameterTypeTree,
@@ -443,27 +472,31 @@ def update_params(workload: Workload,
OptimizerState,
List[Tuple[int, float]],
int,
- RandomState
-],
- UpdateReturn]
+ RandomState,
+ ],
+ UpdateReturn,
+]
# Prepare model and optimizer for evaluation.
-def prepare_for_eval(workload: Workload,
- current_param_container: ParameterContainer,
- current_params_types: ParameterTypeTree,
- model_state: ModelAuxiliaryState,
- hyperparameters: Hyperparameters,
- loss_type: LossType,
- optimizer_state: OptimizerState,
- eval_results: List[Tuple[int, float]],
- global_step: int,
- rng: RandomState) -> UpdateReturn:
+def prepare_for_eval(
+ workload: Workload,
+ current_param_container: ParameterContainer,
+ current_params_types: ParameterTypeTree,
+ model_state: ModelAuxiliaryState,
+ hyperparameters: Hyperparameters,
+ loss_type: LossType,
+ optimizer_state: OptimizerState,
+ eval_results: List[Tuple[int, float]],
+ global_step: int,
+ rng: RandomState,
+) -> UpdateReturn:
"""Return (updated_optimizer_state, updated_params, updated_model_state)."""
pass
-DataSelectionFn = Callable[[
+DataSelectionFn = Callable[
+ [
Workload,
Iterator[Dict[str, Any]],
OptimizerState,
@@ -471,21 +504,24 @@ def prepare_for_eval(workload: Workload,
LossType,
Hyperparameters,
int,
- RandomState
-],
- Tuple[Tensor, Tensor]]
+ RandomState,
+ ],
+ Tuple[Tensor, Tensor],
+]
# Not allowed to update the model parameters, hyperparameters, global step, or
# optimizer state.
-def data_selection(workload: Workload,
- input_queue: Iterator[Dict[str, Any]],
- optimizer_state: OptimizerState,
- current_param_container: ParameterContainer,
- model_state: ModelAuxiliaryState,
- hyperparameters: Hyperparameters,
- global_step: int,
- rng: RandomState) -> Dict[str, Tensor]:
+def data_selection(
+ workload: Workload,
+ input_queue: Iterator[Dict[str, Any]],
+ optimizer_state: OptimizerState,
+ current_param_container: ParameterContainer,
+ model_state: ModelAuxiliaryState,
+ hyperparameters: Hyperparameters,
+ global_step: int,
+ rng: RandomState,
+) -> Dict[str, Tensor]:
"""Select data from the infinitely repeating, pre-shuffled input queue.
Each element of the queue is a batch of training examples and labels.
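
In practice, most submissions implement data_selection as a pass-through that takes the next batch from the pre-shuffled queue; anything fancier (filtering or re-weighting batches) is optional. A minimal sketch with the same signature, unused arguments deleted for clarity:

def data_selection(
  workload,
  input_queue,
  optimizer_state,
  current_param_container,
  model_state,
  hyperparameters,
  global_step,
  rng,
):
  # Pass-through: simply return the next pre-shuffled batch.
  del (
    workload, optimizer_state, current_param_container, model_state,
    hyperparameters, global_step, rng,
  )
  return next(input_queue)
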
diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py
index 728d05f29..3d831c4af 100644
--- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py
+++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py
@@ -17,13 +17,15 @@
from algoperf.data_utils import shard_and_maybe_pad_np
-def preprocess_for_train(image: spec.Tensor,
- rng: spec.RandomState,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float],
- crop_size: int,
- padding_size: int,
- dtype: tf.DType = tf.float32) -> spec.Tensor:
+def preprocess_for_train(
+ image: spec.Tensor,
+ rng: spec.RandomState,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+ crop_size: int,
+ padding_size: int,
+ dtype: tf.DType = tf.float32,
+) -> spec.Tensor:
"""Preprocesses the given image for training.
Args:
@@ -44,20 +46,23 @@ def preprocess_for_train(image: spec.Tensor,
flip_rng = rng[1, :]
image_shape = tf.shape(image)
- image = tf.image.resize_with_crop_or_pad(image,
- image_shape[0] + padding_size,
- image_shape[1] + padding_size)
+ image = tf.image.resize_with_crop_or_pad(
+ image, image_shape[0] + padding_size, image_shape[1] + padding_size
+ )
image = tf.image.stateless_random_crop(
- image, (crop_size, crop_size, 3), seed=crop_rng)
+ image, (crop_size, crop_size, 3), seed=crop_rng
+ )
image = tf.image.stateless_random_flip_left_right(image, seed=flip_rng)
image = normalize_image(image, mean_rgb, stddev_rgb, dtype=dtype)
return image
-def preprocess_for_eval(image: spec.Tensor,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float],
- dtype: tf.DType = tf.float32) -> spec.Tensor:
+def preprocess_for_eval(
+ image: spec.Tensor,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+ dtype: tf.DType = tf.float32,
+) -> spec.Tensor:
"""Preprocesses the given image for evaluation.
Args:
@@ -74,10 +79,12 @@ def preprocess_for_eval(image: spec.Tensor,
return image
-def normalize_image(image: spec.Tensor,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float],
- dtype=tf.float32) -> spec.Tensor:
+def normalize_image(
+ image: spec.Tensor,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+ dtype=tf.float32,
+) -> spec.Tensor:
image = tf.image.convert_image_dtype(image, dtype)
image -= tf.constant(mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
image /= tf.constant(stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
@@ -85,17 +92,17 @@ def normalize_image(image: spec.Tensor,
def create_split(
- split: str,
- dataset_builder: tfds.core.dataset_builder.DatasetBuilder,
- rng: spec.RandomState,
- global_batch_size: int,
- train: bool,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float],
- cache: bool = False,
- repeat_final_dataset: bool = False,
- crop_size: int = 32,
- padding_size: int = 4,
+ split: str,
+ dataset_builder: tfds.core.dataset_builder.DatasetBuilder,
+ rng: spec.RandomState,
+ global_batch_size: int,
+ train: bool,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+ cache: bool = False,
+ repeat_final_dataset: bool = False,
+ crop_size: int = 32,
+ padding_size: int = 4,
) -> Iterator[Dict[str, spec.Tensor]]:
"""Creates a split from the CIFAR-10 dataset using TensorFlow Datasets."""
shuffle_rng, preprocess_rng = jax.random.split(rng, 2)
@@ -104,14 +111,17 @@ def preprocess_example(example_index, example):
dtype = tf.float32
if train:
per_step_preprocess_rng = tf.random.experimental.stateless_fold_in(
- tf.cast(preprocess_rng, tf.int64), example_index)
- image = preprocess_for_train(example['image'],
- per_step_preprocess_rng,
- mean_rgb,
- stddev_rgb,
- crop_size,
- padding_size,
- dtype)
+ tf.cast(preprocess_rng, tf.int64), example_index
+ )
+ image = preprocess_for_train(
+ example['image'],
+ per_step_preprocess_rng,
+ mean_rgb,
+ stddev_rgb,
+ crop_size,
+ padding_size,
+ dtype,
+ )
else:
image = preprocess_for_eval(example['image'], mean_rgb, stddev_rgb, dtype)
return {'inputs': image, 'targets': example['label']}
@@ -132,7 +142,8 @@ def preprocess_example(example_index, example):
# index that we can fold into the RNG seed.
ds = ds.enumerate()
ds = ds.map(
- preprocess_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ preprocess_example, num_parallel_calls=tf.data.experimental.AUTOTUNE
+ )
ds = ds.batch(global_batch_size, drop_remainder=train)
if repeat_final_dataset:
@@ -144,32 +155,36 @@ def preprocess_example(example_index, example):
def create_input_iter(
- split: str,
- dataset_builder: tfds.core.dataset_builder.DatasetBuilder,
- rng: spec.RandomState,
- global_batch_size: int,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float],
- crop_size: int,
- padding_size: int,
- train: bool,
- cache: bool,
- repeat_final_dataset: bool) -> Iterator[Dict[str, spec.Tensor]]:
+ split: str,
+ dataset_builder: tfds.core.dataset_builder.DatasetBuilder,
+ rng: spec.RandomState,
+ global_batch_size: int,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+ crop_size: int,
+ padding_size: int,
+ train: bool,
+ cache: bool,
+ repeat_final_dataset: bool,
+) -> Iterator[Dict[str, spec.Tensor]]:
ds = create_split(
- split,
- dataset_builder,
- rng,
- global_batch_size,
- train=train,
- mean_rgb=mean_rgb,
- stddev_rgb=stddev_rgb,
- cache=cache,
- repeat_final_dataset=repeat_final_dataset,
- crop_size=crop_size,
- padding_size=padding_size)
+ split,
+ dataset_builder,
+ rng,
+ global_batch_size,
+ train=train,
+ mean_rgb=mean_rgb,
+ stddev_rgb=stddev_rgb,
+ cache=cache,
+ repeat_final_dataset=repeat_final_dataset,
+ crop_size=crop_size,
+ padding_size=padding_size,
+ )
it = map(
- functools.partial(
- shard_and_maybe_pad_np, global_batch_size=global_batch_size),
- ds)
+ functools.partial(
+ shard_and_maybe_pad_np, global_batch_size=global_batch_size
+ ),
+ ds,
+ )
it = jax_utils.prefetch_to_device(it, 2)
return it
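
The per-example augmentation above relies on TensorFlow's stateless random ops, so the crop and flip depend only on an explicit seed folded in from the example index. A stand-alone sketch of that pad-crop-flip sequence on a CIFAR-sized image (fixed toy seeds):

import tensorflow as tf

image = tf.random.uniform((32, 32, 3))
crop_seed = tf.constant([0, 1], dtype=tf.int32)
flip_seed = tf.constant([2, 3], dtype=tf.int32)

padded = tf.image.resize_with_crop_or_pad(image, 32 + 4, 32 + 4)
cropped = tf.image.stateless_random_crop(padded, (32, 32, 3), seed=crop_seed)
flipped = tf.image.stateless_random_flip_left_right(cropped, seed=flip_seed)
print(flipped.shape)  # (32, 32, 3)
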
diff --git a/algoperf/workloads/cifar/cifar_jax/models.py b/algoperf/workloads/cifar/cifar_jax/models.py
index 957079272..8d034796f 100644
--- a/algoperf/workloads/cifar/cifar_jax/models.py
+++ b/algoperf/workloads/cifar/cifar_jax/models.py
@@ -25,48 +25,52 @@ class ResNet(nn.Module):
act: Callable = nn.relu
@nn.compact
- def __call__(self,
- x: spec.Tensor,
- update_batch_norm: bool = True,
- use_running_average_bn: bool = None) -> spec.Tensor:
+ def __call__(
+ self,
+ x: spec.Tensor,
+ update_batch_norm: bool = True,
+ use_running_average_bn: bool = None,
+ ) -> spec.Tensor:
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
# Preserve default behavior for backwards compatibility
if use_running_average_bn is None:
use_running_average_bn = not update_batch_norm
norm = functools.partial(
- nn.BatchNorm,
- use_running_average=use_running_average_bn,
- momentum=0.9,
- epsilon=1e-5,
- dtype=self.dtype)
+ nn.BatchNorm,
+ use_running_average=use_running_average_bn,
+ momentum=0.9,
+ epsilon=1e-5,
+ dtype=self.dtype,
+ )
x = conv(
- self.num_filters, (3, 3), (1, 1),
- padding=[(1, 1), (1, 1)],
- name='Conv_init')(
- x)
+ self.num_filters,
+ (3, 3),
+ (1, 1),
+ padding=[(1, 1), (1, 1)],
+ name='Conv_init',
+ )(x)
x = norm(name='BatchNorm_init')(x)
x = nn.relu(x)
for i, block_size in enumerate(self.stage_sizes):
for j in range(block_size):
strides = (2, 2) if i > 0 and j == 0 else (1, 1)
x = self.block_cls(
- self.num_filters * 2**i,
- strides=strides,
- conv=conv,
- norm=norm,
- act=self.act)(
- x)
+ self.num_filters * 2**i,
+ strides=strides,
+ conv=conv,
+ norm=norm,
+ act=self.act,
+ )(x)
x = nn.avg_pool(x, (4, 4), strides=(4, 4))
x = jnp.mean(x, axis=(1, 2))
x = nn.Dense(
- self.num_classes,
- kernel_init=nn.initializers.normal(),
- dtype=self.dtype)(
- x)
+ self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype
+ )(x)
return x
ResNet18 = functools.partial(
- ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock)
+ ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock
+)
diff --git a/algoperf/workloads/cifar/cifar_jax/workload.py b/algoperf/workloads/cifar/cifar_jax/workload.py
index c6cc50fbf..bc26e3899 100644
--- a/algoperf/workloads/cifar/cifar_jax/workload.py
+++ b/algoperf/workloads/cifar/cifar_jax/workload.py
@@ -3,32 +3,30 @@
import functools
from typing import Any, Dict, Iterator, Optional, Tuple
-from flax import jax_utils
-from flax import linen as nn
-from flax.core import pop
import jax
-from jax import lax
import jax.numpy as jnp
import optax
import tensorflow_datasets as tfds
+from flax import jax_utils
+from flax import linen as nn
+from flax.core import pop
+from jax import lax
-from algoperf import param_utils
-from algoperf import spec
+from algoperf import param_utils, spec
from algoperf.workloads.cifar.cifar_jax import models
from algoperf.workloads.cifar.cifar_jax.input_pipeline import create_input_iter
from algoperf.workloads.cifar.workload import BaseCifarWorkload
class CifarWorkload(BaseCifarWorkload):
-
def _build_cifar_dataset(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
) -> Iterator[Dict[str, spec.Tensor]]:
ds_builder = tfds.builder('cifar10:3.0.2', data_dir=data_dir)
train = split == 'train'
@@ -38,38 +36,38 @@ def _build_cifar_dataset(
elif split == 'validation':
split = f'train[{self.num_train_examples}:]'
ds = create_input_iter(
- split,
- ds_builder,
- data_rng,
- batch_size,
- self.train_mean,
- self.train_stddev,
- self.crop_size,
- self.padding_size,
- train=train,
- cache=not train if cache is None else cache,
- repeat_final_dataset=repeat_final_dataset)
+ split,
+ ds_builder,
+ data_rng,
+ batch_size,
+ self.train_mean,
+ self.train_stddev,
+ self.crop_size,
+ self.padding_size,
+ train=train,
+ cache=not train if cache is None else cache,
+ repeat_final_dataset=repeat_final_dataset,
+ )
return ds
def _build_input_queue(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
del num_batches
- return self._build_cifar_dataset(data_rng,
- split,
- data_dir,
- global_batch_size,
- cache,
- repeat_final_dataset)
+ return self._build_cifar_dataset(
+ data_rng, split, data_dir, global_batch_size, cache, repeat_final_dataset
+ )
def sync_batch_stats(
- self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState:
+ self, model_state: spec.ModelAuxiliaryState
+ ) -> spec.ModelAuxiliaryState:
"""Sync the batch statistics across replicas."""
# An axis_name is passed to pmap which can then be used by pmean.
# In this case each device has its own version of the batch statistics
@@ -85,8 +83,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
model = model_cls(num_classes=self._num_classes, dtype=jnp.float32)
self._model = model
input_shape = (1, 32, 32, 3)
- variables = jax.jit(model.init)({'params': rng},
- jnp.ones(input_shape, model.dtype))
+ variables = jax.jit(model.init)(
+ {'params': rng}, jnp.ones(input_shape, model.dtype)
+ )
model_state, params = pop(variables, 'params')
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
@@ -98,43 +97,46 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_0'
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- use_running_average_bn: Optional[bool] = None
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ use_running_average_bn: Optional[bool] = None,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
variables = {'params': params, **model_state}
if update_batch_norm:
logits, new_model_state = self._model.apply(
- variables,
- augmented_and_preprocessed_input_batch['inputs'],
- update_batch_norm=update_batch_norm,
- mutable=['batch_stats'],
- use_running_average_bn=use_running_average_bn)
+ variables,
+ augmented_and_preprocessed_input_batch['inputs'],
+ update_batch_norm=update_batch_norm,
+ mutable=['batch_stats'],
+ use_running_average_bn=use_running_average_bn,
+ )
return logits, new_model_state
else:
logits = self._model.apply(
- variables,
- augmented_and_preprocessed_input_batch['inputs'],
- update_batch_norm=update_batch_norm,
- mutable=False,
- use_running_average_bn=use_running_average_bn)
+ variables,
+ augmented_and_preprocessed_input_batch['inputs'],
+ update_batch_norm=update_batch_norm,
+ mutable=False,
+ use_running_average_bn=use_running_average_bn,
+ )
return logits, model_state
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -144,7 +146,8 @@ def loss_fn(
one_hot_targets = jax.nn.one_hot(label_batch, self._num_classes)
smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing)
per_example_losses = -jnp.sum(
- smoothed_targets * nn.log_softmax(logits_batch), axis=-1)
+ smoothed_targets * nn.log_softmax(logits_batch), axis=-1
+ )
# `mask_batch` is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -153,51 +156,53 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': n_valid_examples,
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': n_valid_examples,
+ 'per_example': per_example_losses,
}
- def _compute_metrics(self,
- logits: spec.Tensor,
- labels: spec.Tensor,
- weights: spec.Tensor) -> Dict[str, spec.Tensor]:
+ def _compute_metrics(
+ self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor
+ ) -> Dict[str, spec.Tensor]:
summed_loss = self.loss_fn(labels, logits, weights)['summed']
# Number of correct predictions.
accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights)
metrics = {
- 'loss': summed_loss,
- 'accuracy': accuracy,
+ 'loss': summed_loss,
+ 'accuracy': accuracy,
}
metrics = lax.psum(metrics, axis_name='batch')
return metrics
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, 0, 0, 0, None),
- static_broadcasted_argnums=(0,))
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, 0, 0, 0, None),
+ static_broadcasted_argnums=(0,),
+ )
def _eval_model(
- self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
"""Return the mean accuracy and loss as a dict."""
logits, _ = self.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.EVAL,
- rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.EVAL,
+ rng,
+ update_batch_norm=False,
+ )
weights = batch.get('weights')
if weights is None:
weights = jnp.ones(len(logits))
return self._compute_metrics(logits, batch['targets'], weights)
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics)
diff --git a/algoperf/workloads/cifar/cifar_pytorch/models.py b/algoperf/workloads/cifar/cifar_pytorch/models.py
index e6a7a8a81..6beef89e6 100644
--- a/algoperf/workloads/cifar/cifar_pytorch/models.py
+++ b/algoperf/workloads/cifar/cifar_pytorch/models.py
@@ -12,23 +12,26 @@
from algoperf import spec
from algoperf.init_utils import pytorch_default_init
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \
- BasicBlock
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import \
- Bottleneck
+from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import (
+ BasicBlock,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import (
+ Bottleneck,
+)
from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import conv1x1
class ResNet(nn.Module):
-
- def __init__(self,
- block: Type[Union[BasicBlock, Bottleneck]],
- layers: List[int],
- num_classes: int = 10,
- groups: int = 1,
- width_per_group: int = 64,
- replace_stride_with_dilation: Optional[List[bool]] = None,
- norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
+ def __init__(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ num_classes: int = 10,
+ groups: int = 1,
+ width_per_group: int = 64,
+ replace_stride_with_dilation: Optional[List[bool]] = None,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ ) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
@@ -42,21 +45,26 @@ def __init__(self,
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError(
- 'replace_stride_with_dilation should be None '
- f'or a 3-element tuple, got {replace_stride_with_dilation}')
+ 'replace_stride_with_dilation should be None '
+ f'or a 3-element tuple, got {replace_stride_with_dilation}'
+ )
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(
- 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
+ )
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(
- block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
+ block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
+ )
self.layer3 = self._make_layer(
- block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
+ block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
+ )
self.layer4 = self._make_layer(
- block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
+ block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
+ )
self.fc = nn.Linear(512 * block.expansion, num_classes)
self.reset_parameters()
@@ -68,7 +76,7 @@ def reset_parameters(self) -> None:
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
nn.init.normal_(self.fc.weight, std=1e-2)
- nn.init.constant_(self.fc.bias, 0.)
+ nn.init.constant_(self.fc.bias, 0.0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros,
@@ -81,12 +89,14 @@ def reset_parameters(self) -> None:
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
- def _make_layer(self,
- block: Type[Union[BasicBlock, Bottleneck]],
- planes: int,
- blocks: int,
- stride: int = 1,
- dilate: bool = False) -> nn.Sequential:
+ def _make_layer(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ planes: int,
+ blocks: int,
+ stride: int = 1,
+ dilate: bool = False,
+ ) -> nn.Sequential:
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
@@ -95,32 +105,39 @@ def _make_layer(self,
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = torch.nn.Sequential(
- collections.OrderedDict([
- ("conv", conv1x1(self.inplanes, planes * block.expansion,
- stride)),
- ("bn", norm_layer(planes * block.expansion)),
- ]))
+ collections.OrderedDict(
+ [
+ ('conv', conv1x1(self.inplanes, planes * block.expansion, stride)),
+ ('bn', norm_layer(planes * block.expansion)),
+ ]
+ )
+ )
layers = []
layers.append(
- block(self.inplanes,
- planes,
- stride,
- downsample,
- self.groups,
- self.base_width,
- previous_dilation,
- norm_layer))
+ block(
+ self.inplanes,
+ planes,
+ stride,
+ downsample,
+ self.groups,
+ self.base_width,
+ previous_dilation,
+ norm_layer,
+ )
+ )
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
- block(
- self.inplanes,
- planes,
- groups=self.groups,
- base_width=self.base_width,
- dilation=self.dilation,
- norm_layer=norm_layer))
+ block(
+ self.inplanes,
+ planes,
+ groups=self.groups,
+ base_width=self.base_width,
+ dilation=self.dilation,
+ norm_layer=norm_layer,
+ )
+ )
return nn.Sequential(*layers)
diff --git a/algoperf/workloads/cifar/cifar_pytorch/workload.py b/algoperf/workloads/cifar/cifar_pytorch/workload.py
index d05131c27..d7e858226 100644
--- a/algoperf/workloads/cifar/cifar_pytorch/workload.py
+++ b/algoperf/workloads/cifar/cifar_pytorch/workload.py
@@ -23,7 +23,6 @@
class CifarWorkload(BaseCifarWorkload):
-
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# Is set in submission_runner.py for workloads with PyTorch evaluation
@@ -34,7 +33,8 @@ def __init__(self, *args, **kwargs) -> None:
def eval_num_workers(self) -> int:
if self._eval_num_workers is None:
raise ValueError(
- 'eval_num_workers property must be set before workload is used.')
+ 'eval_num_workers property must be set before workload is used.'
+ )
return self._eval_num_workers
@eval_num_workers.setter
@@ -42,47 +42,54 @@ def eval_num_workers(self, eval_num_workers: int):
self._eval_num_workers = eval_num_workers
def _build_dataset(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
) -> torch.utils.data.DataLoader:
del cache
del repeat_final_dataset
is_train = split == 'train'
- normalize = transforms.Compose([
+ normalize = transforms.Compose(
+ [
transforms.ToTensor(),
transforms.Normalize(mean=self.train_mean, std=self.train_stddev),
- ])
+ ]
+ )
eval_transform_config = normalize
- train_transform_config = transforms.Compose([
+ train_transform_config = transforms.Compose(
+ [
transforms.RandomCrop(
- size=self.crop_size,
- padding=self.padding_size,
+ size=self.crop_size,
+ padding=self.padding_size,
),
transforms.RandomHorizontalFlip(),
normalize,
- ])
+ ]
+ )
transform = train_transform_config if is_train else eval_transform_config
dataset = CIFAR10(
- root=data_dir,
- train=split in ['train', 'eval_train', 'validation'],
- download=False,
- transform=transform)
+ root=data_dir,
+ train=split in ['train', 'eval_train', 'validation'],
+ download=False,
+ transform=transform,
+ )
assert self.num_train_examples + self.num_validation_examples == 50000
indices = list(range(50000))
indices_split = {
- 'train': indices[:self.num_train_examples],
- 'validation': indices[self.num_train_examples:],
+ 'train': indices[: self.num_train_examples],
+ 'validation': indices[self.num_train_examples :],
}
if split == 'eval_train':
train_indices = indices_split['train']
random.Random(int(data_rng[0])).shuffle(train_indices)
- indices_split['eval_train'] = train_indices[:self.num_eval_train_examples]
+ indices_split['eval_train'] = train_indices[
+ : self.num_eval_train_examples
+ ]
if split in indices_split:
dataset = torch.utils.data.Subset(dataset, indices_split[split])
@@ -92,30 +99,34 @@ def _build_dataset(
ds_iter_batch_size = per_device_batch_size
if is_train:
sampler = torch.utils.data.distributed.DistributedSampler(
- dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True)
+ dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True
+ )
else:
sampler = data_utils.DistributedEvalSampler(
- dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False)
+ dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False
+ )
else:
ds_iter_batch_size = global_batch_size
dataloader = torch.utils.data.DataLoader(
- dataset,
- batch_size=ds_iter_batch_size,
- shuffle=not USE_PYTORCH_DDP and is_train,
- sampler=sampler,
- num_workers=4 if is_train else self.eval_num_workers,
- pin_memory=True,
- drop_last=is_train)
+ dataset,
+ batch_size=ds_iter_batch_size,
+ shuffle=not USE_PYTORCH_DDP and is_train,
+ sampler=sampler,
+ num_workers=4 if is_train else self.eval_num_workers,
+ pin_memory=True,
+ drop_last=is_train,
+ )
dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE)
dataloader = data_utils.cycle(dataloader, custom_sampler=USE_PYTORCH_DDP)
return dataloader
def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ self,
+ rng: spec.RandomState,
+ dropout_rate: Optional[float] = None,
+ aux_dropout_rate: Optional[float] = None,
+ ) -> spec.ModelInitState:
"""Dropout is unused."""
del dropout_rate
del aux_dropout_rate
@@ -143,30 +154,34 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['fc.weight', 'fc.bias']
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
model = params
if mode == spec.ForwardPassMode.EVAL:
if update_batch_norm:
raise ValueError(
- 'Batch norm statistics cannot be updated during evaluation.')
+ 'Batch norm statistics cannot be updated during evaluation.'
+ )
model.eval()
if mode == spec.ForwardPassMode.TRAIN:
model.train()
model.apply(
- functools.partial(
- pytorch_utils.update_batch_norm_fn,
- update_batch_norm=update_batch_norm))
+ functools.partial(
+ pytorch_utils.update_batch_norm_fn,
+ update_batch_norm=update_batch_norm,
+ )
+ )
contexts = {
- spec.ForwardPassMode.EVAL: torch.no_grad,
- spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
+ spec.ForwardPassMode.EVAL: torch.no_grad,
+ spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
logits_batch = model(augmented_and_preprocessed_input_batch['inputs'])
@@ -175,11 +190,12 @@ def model_fn(
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -187,10 +203,11 @@ def loss_fn(
(not synced across devices).
"""
per_example_losses = F.cross_entropy(
- logits_batch,
- label_batch,
- reduction='none',
- label_smoothing=label_smoothing)
+ logits_batch,
+ label_batch,
+ reduction='none',
+ label_smoothing=label_smoothing,
+ )
# `mask_batch` is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -199,25 +216,27 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
+ 'per_example': per_example_losses,
}
def _eval_model(
- self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
"""Return the mean accuracy and loss as a dict."""
logits, _ = self.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.EVAL,
- rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.EVAL,
+ rng,
+ update_batch_norm=False,
+ )
targets = batch['targets']
weights = batch.get('weights')
if weights is None:
@@ -229,8 +248,8 @@ def _eval_model(
return {'accuracy': accuracy, 'loss': summed_loss}
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
dist.all_reduce(metric)
diff --git a/algoperf/workloads/cifar/workload.py b/algoperf/workloads/cifar/workload.py
index c0d565108..61880fbfa 100644
--- a/algoperf/workloads/cifar/workload.py
+++ b/algoperf/workloads/cifar/workload.py
@@ -15,7 +15,6 @@
class BaseCifarWorkload(spec.Workload):
-
_num_classes: int = 10
@property
@@ -23,8 +22,9 @@ def target_metric_name(self) -> str:
"""The name of the target metric (useful for scoring/processing code)."""
return 'accuracy'
- def has_reached_validation_target(self, eval_result: Dict[str,
- float]) -> bool:
+ def has_reached_validation_target(
+ self, eval_result: Dict[str, float]
+ ) -> bool:
return eval_result['validation/accuracy'] > self.validation_target_value
@property
@@ -51,8 +51,9 @@ def num_eval_train_examples(self) -> int:
# Round up from num_validation_examples (which is the default for
# num_eval_train_examples) to the next multiple of eval_batch_size, so that
# we don't have to extract the correctly sized subset of the training data.
- rounded_up_multiple = math.ceil(self.num_validation_examples /
- self.eval_batch_size)
+ rounded_up_multiple = math.ceil(
+ self.num_validation_examples / self.eval_batch_size
+ )
return rounded_up_multiple * self.eval_batch_size
@property
@@ -93,37 +94,35 @@ def eval_period_time_sec(self) -> int:
return 600 # 10 mins.
def _build_dataset(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
) -> Iterator[Dict[str, spec.Tensor]]:
raise NotImplementedError
def _build_input_queue(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
del num_batches
if split == 'test':
if not cache:
raise ValueError('cache must be True for split=test.')
if not repeat_final_dataset:
raise ValueError('repeat_final_dataset must be True for split=test.')
- return self._build_dataset(data_rng,
- split,
- data_dir,
- global_batch_size,
- cache,
- repeat_final_dataset)
+ return self._build_dataset(
+ data_rng, split, data_dir, global_batch_size, cache, repeat_final_dataset
+ )
@property
def step_hint(self) -> int:
@@ -133,39 +132,43 @@ def step_hint(self) -> int:
return 4883
def _eval_model(
- self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
raise NotImplementedError
@abc.abstractmethod
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del global_step
data_rng, model_rng = prng.split(rng, 2)
if split not in self._eval_iters:
self._eval_iters[split] = self._build_input_queue(
- data_rng=data_rng,
- split=split,
- data_dir=data_dir,
- global_batch_size=global_batch_size,
- cache=True,
- repeat_final_dataset=True)
+ data_rng=data_rng,
+ split=split,
+ data_dir=data_dir,
+ global_batch_size=global_batch_size,
+ cache=True,
+ repeat_final_dataset=True,
+ )
num_batches = int(math.ceil(num_examples / global_batch_size))
num_devices = max(torch.cuda.device_count(), jax.local_device_count())
@@ -174,10 +177,9 @@ def _eval_model_on_split(self,
batch = next(self._eval_iters[split])
per_device_model_rngs = prng.split(model_rng, num_devices)
# We already average these metrics across devices inside _compute_metrics.
- synced_metrics = self._eval_model(params,
- batch,
- model_state,
- per_device_model_rngs)
+ synced_metrics = self._eval_model(
+ params, batch, model_state, per_device_model_rngs
+ )
for metric_name, metric_value in synced_metrics.items():
if metric_name not in eval_metrics:
eval_metrics[metric_name] = 0.0
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
index 4a91a80b8..706c2b51a 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/models.py
@@ -1,9 +1,10 @@
"""A JAX implementation of DLRM-Small."""
+
from typing import Sequence
import flax.linen as nn
-from jax import nn as jnn
import jax.numpy as jnp
+from jax import nn as jnn
from algoperf.jax_utils import Dropout
@@ -32,7 +33,6 @@ class DLRMResNet(nn.Module):
@nn.compact
def __call__(self, x, train, dropout_rate=DROPOUT_RATE):
-
bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
@@ -40,20 +40,18 @@ def __call__(self, x, train, dropout_rate=DROPOUT_RATE):
mlp_bottom_dims = self.mlp_bottom_dims
bot_mlp_input = nn.Dense(
- mlp_bottom_dims[0],
- kernel_init=jnn.initializers.glorot_uniform(),
- bias_init=jnn.initializers.normal(stddev=1.0 / mlp_bottom_dims[0]**0.5),
- )(
- bot_mlp_input)
+ mlp_bottom_dims[0],
+ kernel_init=jnn.initializers.glorot_uniform(),
+ bias_init=jnn.initializers.normal(stddev=1.0 / mlp_bottom_dims[0] ** 0.5),
+ )(bot_mlp_input)
bot_mlp_input = nn.relu(bot_mlp_input)
for dense_dim in mlp_bottom_dims[1:]:
x = nn.Dense(
- dense_dim,
- kernel_init=jnn.initializers.glorot_uniform(),
- bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5),
- )(
- bot_mlp_input)
+ dense_dim,
+ kernel_init=jnn.initializers.glorot_uniform(),
+ bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5),
+ )(bot_mlp_input)
bot_mlp_input += nn.relu(x)
base_init_fn = jnn.initializers.uniform(scale=1.0)
@@ -63,34 +61,38 @@ def __call__(self, x, train, dropout_rate=DROPOUT_RATE):
def scaled_init(key, shape, dtype=jnp.float_):
return base_init_fn(key, shape, dtype) / jnp.sqrt(self.vocab_size)
- embedding_table = self.param('embedding_table',
- scaled_init, [self.vocab_size, self.embed_dim])
+ embedding_table = self.param(
+ 'embedding_table', scaled_init, [self.vocab_size, self.embed_dim]
+ )
embed_features = embedding_table[idx_lookup]
batch_size = bot_mlp_input.shape[0]
- embed_features = jnp.reshape(embed_features,
- (batch_size, 26 * self.embed_dim))
+ embed_features = jnp.reshape(
+ embed_features, (batch_size, 26 * self.embed_dim)
+ )
top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1)
mlp_input_dim = top_mlp_input.shape[1]
mlp_top_dims = self.mlp_top_dims
num_layers_top = len(mlp_top_dims)
top_mlp_input = nn.Dense(
- mlp_top_dims[0],
- kernel_init=jnn.initializers.normal(
- stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0]))),
- bias_init=jnn.initializers.normal(
- stddev=jnp.sqrt(1.0 / mlp_top_dims[0])))(
- top_mlp_input)
+ mlp_top_dims[0],
+ kernel_init=jnn.initializers.normal(
+ stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0]))
+ ),
+ bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / mlp_top_dims[0])),
+ )(top_mlp_input)
top_mlp_input = nn.relu(top_mlp_input)
for layer_idx, fan_out in list(enumerate(mlp_top_dims))[1:-1]:
fan_in = mlp_top_dims[layer_idx - 1]
x = nn.Dense(
- fan_out,
- kernel_init=jnn.initializers.normal(
- stddev=jnp.sqrt(2.0 / (fan_in + fan_out))),
- bias_init=jnn.initializers.normal(
- stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))(
- top_mlp_input)
+ fan_out,
+ kernel_init=jnn.initializers.normal(
+ stddev=jnp.sqrt(2.0 / (fan_in + fan_out))
+ ),
+ bias_init=jnn.initializers.normal(
+ stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])
+ ),
+ )(top_mlp_input)
x = nn.relu(x)
if dropout_rate and layer_idx == num_layers_top - 2:
x = Dropout(dropout_rate, deterministic=not train)(x, rate=dropout_rate)
@@ -98,11 +100,12 @@ def scaled_init(key, shape, dtype=jnp.float_):
# In the DLRM model the last layer width is always 1. We can hardcode that
# below.
logits = nn.Dense(
- 1,
- kernel_init=jnn.initializers.normal(
- stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))),
- bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0)))(
- top_mlp_input)
+ 1,
+ kernel_init=jnn.initializers.normal(
+ stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))
+ ),
+ bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0)),
+ )(top_mlp_input)
return logits
@@ -118,16 +121,18 @@ def dot_interact(concat_features):
batch_size = concat_features.shape[0]
# Interact features, select upper or lower-triangular portion, and reshape.
- xactions = jnp.matmul(concat_features,
- jnp.transpose(concat_features, [0, 2, 1]))
+ xactions = jnp.matmul(
+ concat_features, jnp.transpose(concat_features, [0, 2, 1])
+ )
feature_dim = xactions.shape[-1]
indices = jnp.array(jnp.triu_indices(feature_dim))
num_elems = indices.shape[1]
indices = jnp.tile(indices, [1, batch_size])
indices0 = jnp.reshape(
- jnp.tile(jnp.reshape(jnp.arange(batch_size), [-1, 1]), [1, num_elems]),
- [1, -1])
+ jnp.tile(jnp.reshape(jnp.arange(batch_size), [-1, 1]), [1, num_elems]),
+ [1, -1],
+ )
indices = tuple(jnp.concatenate((indices0, indices), 0))
activations = xactions[indices]
activations = jnp.reshape(activations, [batch_size, -1])
@@ -156,25 +161,24 @@ class DlrmSmall(nn.Module):
@nn.compact
def __call__(self, x, train, dropout_rate=DROPOUT_RATE):
-
bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
# Bottom MLP.
for dense_dim in self.mlp_bottom_dims:
bot_mlp_input = nn.Dense(
- dense_dim,
- kernel_init=jnn.initializers.glorot_uniform(),
- bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)),
- )(
- bot_mlp_input)
+ dense_dim,
+ kernel_init=jnn.initializers.glorot_uniform(),
+ bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / dense_dim)),
+ )(bot_mlp_input)
bot_mlp_input = nn.relu(bot_mlp_input)
if self.use_layer_norm:
bot_mlp_input = nn.LayerNorm()(bot_mlp_input)
bot_mlp_output = bot_mlp_input
batch_size = bot_mlp_output.shape[0]
- feature_stack = jnp.reshape(bot_mlp_output,
- [batch_size, -1, self.embed_dim])
+ feature_stack = jnp.reshape(
+ bot_mlp_output, [batch_size, -1, self.embed_dim]
+ )
# Embedding table look-up.
idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size
@@ -187,38 +191,45 @@ def __call__(self, x, train, dropout_rate=DROPOUT_RATE):
def scaled_init(key, shape, dtype=jnp.float_):
return jnn.initializers.uniform(scale=1.0)(key, shape, dtype) * scale
- embedding_table = self.param('embedding_table',
- scaled_init, [self.vocab_size, self.embed_dim])
+ embedding_table = self.param(
+ 'embedding_table', scaled_init, [self.vocab_size, self.embed_dim]
+ )
idx_lookup = jnp.reshape(idx_lookup, [-1])
embed_features = embedding_table[idx_lookup]
- embed_features = jnp.reshape(embed_features,
- [batch_size, -1, self.embed_dim])
+ embed_features = jnp.reshape(
+ embed_features, [batch_size, -1, self.embed_dim]
+ )
if self.use_layer_norm:
embed_features = nn.LayerNorm()(embed_features)
feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1)
dot_interact_output = dot_interact(concat_features=feature_stack)
- top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output],
- axis=-1)
+ top_mlp_input = jnp.concatenate(
+ [bot_mlp_output, dot_interact_output], axis=-1
+ )
mlp_input_dim = top_mlp_input.shape[1]
mlp_top_dims = self.mlp_top_dims
num_layers_top = len(mlp_top_dims)
for layer_idx, fan_out in enumerate(mlp_top_dims):
fan_in = mlp_input_dim if layer_idx == 0 else mlp_top_dims[layer_idx - 1]
top_mlp_input = nn.Dense(
- fan_out,
- kernel_init=jnn.initializers.normal(
- stddev=jnp.sqrt(2.0 / (fan_in + fan_out))),
- bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)))(
- top_mlp_input)
+ fan_out,
+ kernel_init=jnn.initializers.normal(
+ stddev=jnp.sqrt(2.0 / (fan_in + fan_out))
+ ),
+ bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0 / fan_out)),
+ )(top_mlp_input)
if layer_idx < (num_layers_top - 1):
top_mlp_input = nn.relu(top_mlp_input)
if self.use_layer_norm:
top_mlp_input = nn.LayerNorm()(top_mlp_input)
- if (dropout_rate is not None and dropout_rate > 0.0 and
- layer_idx == num_layers_top - 2):
- top_mlp_input = Dropout(
- dropout_rate, deterministic=not train)(
- top_mlp_input, rate=dropout_rate)
+ if (
+ dropout_rate is not None
+ and dropout_rate > 0.0
+ and layer_idx == num_layers_top - 2
+ ):
+ top_mlp_input = Dropout(dropout_rate, deterministic=not train)(
+ top_mlp_input, rate=dropout_rate
+ )
logits = top_mlp_input
return logits
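
The two DLRM hunks above apply the custom `Dropout` imported from `algoperf.jax_utils` with the rate passed again at call time. A minimal sketch of that calling convention, assuming only what the usage above shows (a `rate` keyword accepted in `__call__` and a `'dropout'` RNG collection); the toy module and shapes are illustrative, not part of the patch:

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    from algoperf.jax_utils import Dropout

    class TinyTopMLP(nn.Module):
      @nn.compact
      def __call__(self, x, train, dropout_rate=0.1):
        x = nn.relu(nn.Dense(16)(x))
        # The rate is supplied both at construction and at call time,
        # mirroring the DLRM usage above, so it can be changed between calls
        # without rebuilding the model.
        x = Dropout(dropout_rate, deterministic=not train)(x, rate=dropout_rate)
        return nn.Dense(1)(x)

    model = TinyTopMLP()
    x = jnp.ones((4, 8))
    # As in the workload code, initialization runs with train=False, so no
    # dropout RNG is needed at init time.
    variables = model.init(jax.random.PRNGKey(0), x, train=False)
    logits = model.apply(
        variables,
        x,
        train=True,
        dropout_rate=0.2,
        rngs={'dropout': jax.random.PRNGKey(1)},
    )
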
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
index d84d18d5c..283b3be8e 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -3,26 +3,24 @@
import functools
from typing import Dict, Optional, Tuple
-from flax import jax_utils
import jax
import jax.numpy as jnp
import numpy as np
+from flax import jax_utils
-from algoperf import param_utils
-from algoperf import spec
+from algoperf import param_utils, spec
from algoperf.workloads.criteo1tb.criteo1tb_jax import models
-from algoperf.workloads.criteo1tb.workload import \
- BaseCriteo1TbDlrmSmallWorkload
+from algoperf.workloads.criteo1tb.workload import BaseCriteo1TbDlrmSmallWorkload
class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload):
-
@property
def eval_batch_size(self) -> int:
return 131_072
def _per_example_sigmoid_binary_cross_entropy(
- self, logits: spec.Tensor, targets: spec.Tensor) -> spec.Tensor:
+ self, logits: spec.Tensor, targets: spec.Tensor
+ ) -> spec.Tensor:
"""Computes the sigmoid binary cross entropy per example.
Args:
@@ -39,11 +37,12 @@ def _per_example_sigmoid_binary_cross_entropy(
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense (not one-hot) labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense (not one-hot) labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -55,7 +54,8 @@ def loss_fn(
label_batch = jnp.reshape(label_batch, (batch_size,))
logits_batch = jnp.reshape(logits_batch, (batch_size,))
per_example_losses = self._per_example_sigmoid_binary_cross_entropy(
- logits=logits_batch, targets=label_batch)
+ logits=logits_batch, targets=label_batch
+ )
if mask_batch is not None:
mask_batch = jnp.reshape(mask_batch, (batch_size,))
per_example_losses *= mask_batch
@@ -64,15 +64,15 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': n_valid_examples,
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': n_valid_examples,
+ 'per_example': per_example_losses,
}
def init_model_fn(
- self,
- rng: spec.RandomState,
- tabulate: Optional[bool] = False,
+ self,
+ rng: spec.RandomState,
+ tabulate: Optional[bool] = False,
) -> spec.ModelInitState:
"""Only dropout is used."""
if self.use_resnet:
@@ -81,13 +81,14 @@ def init_model_fn(
model_class = models.DlrmSmall
self._model = model_class(
- vocab_size=self.vocab_size,
- num_dense_features=self.num_dense_features,
- mlp_bottom_dims=self.mlp_bottom_dims,
- mlp_top_dims=self.mlp_top_dims,
- embed_dim=self.embed_dim,
- use_layer_norm=self.use_layer_norm,
- embedding_init_multiplier=self.embedding_init_multiplier)
+ vocab_size=self.vocab_size,
+ num_dense_features=self.num_dense_features,
+ mlp_bottom_dims=self.mlp_bottom_dims,
+ mlp_top_dims=self.mlp_top_dims,
+ embed_dim=self.embed_dim,
+ use_layer_norm=self.use_layer_norm,
+ embedding_init_multiplier=self.embedding_init_multiplier,
+ )
params_rng, _ = jax.random.split(rng)
init_fake_batch_size = 2
@@ -96,10 +97,12 @@ def init_model_fn(
input_size = num_dense_features + num_categorical_features
input_shape = (init_fake_batch_size, input_size)
init_fn = functools.partial(self._model.init, train=False)
- initial_variables = jax.jit(init_fn)({
+ initial_variables = jax.jit(init_fn)(
+ {
'params': params_rng,
- },
- jnp.ones(input_shape, jnp.float32))
+ },
+ jnp.ones(input_shape, jnp.float32),
+ )
initial_params = initial_variables['params']
self._param_shapes = param_utils.jax_param_shapes(initial_params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
@@ -109,14 +112,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_7'
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
@@ -130,35 +133,38 @@ def model_fn(
return logits_batch, None
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, 0, 0),
- static_broadcasted_argnums=(0,))
- def _eval_batch_pmapped(self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor]) -> spec.Tensor:
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, 0, 0),
+ static_broadcasted_argnums=(0,),
+ )
+ def _eval_batch_pmapped(
+ self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]
+ ) -> spec.Tensor:
logits, _ = self.model_fn(
- params,
- batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
weights = batch.get('weights')
if weights is None:
weights = jnp.ones(len(logits))
summed_loss = self.loss_fn(
- label_batch=batch['targets'], logits_batch=logits,
- mask_batch=weights)['summed']
+ label_batch=batch['targets'], logits_batch=logits, mask_batch=weights
+ )['summed']
return summed_loss
- def _eval_batch(self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor]) -> spec.Tensor:
+ def _eval_batch(
+ self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]
+ ) -> spec.Tensor:
# We do NOT psum inside of _eval_batch_pmapped, so the returned tensor of
# shape (local_device_count,) will all be different values.
return np.array(
- self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64)
+ self._eval_batch_pmapped(params, batch).sum(), dtype=np.float64
+ )
class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
@@ -166,7 +172,6 @@ class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
class Criteo1TbDlrmSmallLayerNormWorkload(Criteo1TbDlrmSmallWorkload):
-
@property
def use_layer_norm(self) -> bool:
"""Whether or not to use LayerNorm in the model."""
@@ -200,7 +205,6 @@ def test_target_value(self) -> float:
class Criteo1TbDlrmSmallEmbedInitWorkload(Criteo1TbDlrmSmallWorkload):
-
@property
def validation_target_value(self) -> float:
return 0.129657
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py
index 7574de3a7..1906bf7ae 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/models.py
@@ -5,14 +5,13 @@
import torch
from torch import nn
-from algoperf.pytorch_utils import CustomDropout
-from algoperf.pytorch_utils import SequentialWithDropout
+from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout
DROPOUT_RATE = 0.0
class DenseBlock(nn.Module):
- """Dense block with optional residual connection.""" ""
+ """Dense block with optional residual connection.""" ''
def __init__(self, module, resnet=False):
super().__init__()
@@ -41,17 +40,20 @@ class DotInteract(nn.Module):
def __init__(self, num_sparse_features):
super().__init__()
- self.triu_indices = torch.triu_indices(num_sparse_features + 1,
- num_sparse_features + 1)
+ self.triu_indices = torch.triu_indices(
+ num_sparse_features + 1, num_sparse_features + 1
+ )
def forward(self, dense_features, sparse_features):
- combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features),
- dim=1)
- interactions = torch.bmm(combined_values,
- torch.transpose(combined_values, 1, 2))
- interactions_flat = interactions[:,
- self.triu_indices[0],
- self.triu_indices[1]]
+ combined_values = torch.cat(
+ (dense_features.unsqueeze(1), sparse_features), dim=1
+ )
+ interactions = torch.bmm(
+ combined_values, torch.transpose(combined_values, 1, 2)
+ )
+ interactions_flat = interactions[
+ :, self.triu_indices[0], self.triu_indices[1]
+ ]
return torch.cat((dense_features, interactions_flat), dim=1)
@@ -66,15 +68,17 @@ class DLRMResNet(nn.Module):
embed_dim: embedding dimension.
"""
- def __init__(self,
- vocab_size,
- num_dense_features=13,
- num_sparse_features=26,
- mlp_bottom_dims=(256, 256, 256),
- mlp_top_dims=(256, 256, 256, 256, 1),
- embed_dim=128,
- use_layer_norm=False,
- embedding_init_multiplier=None):
+ def __init__(
+ self,
+ vocab_size,
+ num_dense_features=13,
+ num_sparse_features=26,
+ mlp_bottom_dims=(256, 256, 256),
+ mlp_top_dims=(256, 256, 256, 256, 1),
+ embed_dim=128,
+ use_layer_norm=False,
+ embedding_init_multiplier=None,
+ ):
super().__init__()
self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32)
self.num_dense_features = num_dense_features
@@ -92,7 +96,8 @@ def __init__(self,
scale = 1.0 / torch.sqrt(self.vocab_size)
for i in range(num_chunks):
chunk = nn.Parameter(
- torch.Tensor(self.vocab_size // num_chunks, self.embed_dim))
+ torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)
+ )
chunk.data.uniform_(0, 1)
chunk.data = scale * chunk.data
self.register_parameter(f'embedding_chunk_{i}', chunk)
@@ -115,11 +120,11 @@ def __init__(self,
for module in self.bot_mlp.modules():
if isinstance(module, nn.Linear):
- limit = math.sqrt(6. / (module.in_features + module.out_features))
+ limit = math.sqrt(6.0 / (module.in_features + module.out_features))
nn.init.uniform_(module.weight.data, -limit, limit)
- nn.init.normal_(module.bias.data,
- 0.,
- math.sqrt(1. / module.out_features))
+ nn.init.normal_(
+ module.bias.data, 0.0, math.sqrt(1.0 / module.out_features)
+ )
# Number of sparse features = 26
fan_in = (26 * self.embed_dim) + self.mlp_bottom_dims[-1]
@@ -144,19 +149,20 @@ def __init__(self,
for module in self.top_mlp.modules():
if isinstance(module, nn.Linear):
nn.init.normal_(
- module.weight.data,
- 0.,
- math.sqrt(2. / (module.in_features + module.out_features)))
- nn.init.normal_(module.bias.data,
- 0.,
- math.sqrt(1. / module.out_features))
+ module.weight.data,
+ 0.0,
+ math.sqrt(2.0 / (module.in_features + module.out_features)),
+ )
+ nn.init.normal_(
+ module.bias.data, 0.0, math.sqrt(1.0 / module.out_features)
+ )
def forward(self, x, dropout_rate=DROPOUT_RATE):
-
batch_size = x.shape[0]
dense_features, sparse_features = torch.split(
- x, [self.num_dense_features, self.num_sparse_features], 1)
+ x, [self.num_dense_features, self.num_sparse_features], 1
+ )
# Bottom MLP.
embedded_dense = self.bot_mlp(dense_features)
@@ -166,8 +172,9 @@ def forward(self, x, dropout_rate=DROPOUT_RATE):
idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size
embedding_table = torch.cat(self.embedding_table_chucks, dim=0)
embedded_sparse = embedding_table[idx_lookup]
- embedded_sparse = torch.reshape(embedded_sparse,
- [batch_size, 26 * self.embed_dim])
+ embedded_sparse = torch.reshape(
+ embedded_sparse, [batch_size, 26 * self.embed_dim]
+ )
top_mlp_input = torch.cat([embedded_dense, embedded_sparse], axis=1)
# Final MLP.
@@ -186,15 +193,17 @@ class DlrmSmall(nn.Module):
embed_dim: embedding dimension.
"""
- def __init__(self,
- vocab_size,
- num_dense_features=13,
- num_sparse_features=26,
- mlp_bottom_dims=(512, 256, 128),
- mlp_top_dims=(1024, 1024, 512, 256, 1),
- embed_dim=128,
- use_layer_norm=False,
- embedding_init_multiplier=None):
+ def __init__(
+ self,
+ vocab_size,
+ num_dense_features=13,
+ num_sparse_features=26,
+ mlp_bottom_dims=(512, 256, 128),
+ mlp_top_dims=(1024, 1024, 512, 256, 1),
+ embed_dim=128,
+ use_layer_norm=False,
+ embedding_init_multiplier=None,
+ ):
super().__init__()
self.vocab_size = torch.tensor(vocab_size, dtype=torch.int32)
self.num_dense_features = num_dense_features
@@ -218,7 +227,8 @@ def __init__(self,
for i in range(num_chunks):
chunk = nn.Parameter(
- torch.Tensor(self.vocab_size // num_chunks, self.embed_dim))
+ torch.Tensor(self.vocab_size // num_chunks, self.embed_dim)
+ )
chunk.data.uniform_(0, 1)
chunk.data = scale * chunk.data
self.register_parameter(f'embedding_chunk_{i}', chunk)
@@ -235,21 +245,24 @@ def __init__(self,
self.bot_mlp = nn.Sequential(*bottom_mlp_layers)
for module in self.bot_mlp.modules():
if isinstance(module, nn.Linear):
- limit = math.sqrt(6. / (module.in_features + module.out_features))
+ limit = math.sqrt(6.0 / (module.in_features + module.out_features))
nn.init.uniform_(module.weight.data, -limit, limit)
- nn.init.normal_(module.bias.data,
- 0.,
- math.sqrt(1. / module.out_features))
+ nn.init.normal_(
+ module.bias.data, 0.0, math.sqrt(1.0 / module.out_features)
+ )
- self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,)
+ self.dot_interact = DotInteract(
+ num_sparse_features=num_sparse_features,
+ )
# TODO: Write down the formula here instead of the constant.
input_dims = 506
num_layers_top = len(self.mlp_top_dims)
top_mlp_layers = []
for layer_idx, fan_out in enumerate(self.mlp_top_dims):
- fan_in = input_dims if layer_idx == 0 \
- else self.mlp_top_dims[layer_idx - 1]
+ fan_in = (
+ input_dims if layer_idx == 0 else self.mlp_top_dims[layer_idx - 1]
+ )
top_mlp_layers.append(nn.Linear(fan_in, fan_out))
if layer_idx < (num_layers_top - 1):
top_mlp_layers.append(nn.ReLU(inplace=True))
@@ -265,19 +278,20 @@ def __init__(self,
for module in self.top_mlp.modules():
if isinstance(module, nn.Linear):
nn.init.normal_(
- module.weight.data,
- 0.,
- math.sqrt(2. / (module.in_features + module.out_features)))
- nn.init.normal_(module.bias.data,
- 0.,
- math.sqrt(1. / module.out_features))
+ module.weight.data,
+ 0.0,
+ math.sqrt(2.0 / (module.in_features + module.out_features)),
+ )
+ nn.init.normal_(
+ module.bias.data, 0.0, math.sqrt(1.0 / module.out_features)
+ )
def forward(self, x, dropout_rate=DROPOUT_RATE):
-
batch_size = x.shape[0]
dense_features, sparse_features = torch.split(
- x, [self.num_dense_features, self.num_sparse_features], 1)
+ x, [self.num_dense_features, self.num_sparse_features], 1
+ )
# Bottom MLP.
embedded_dense = self.bot_mlp(dense_features)
@@ -287,13 +301,15 @@ def forward(self, x, dropout_rate=DROPOUT_RATE):
idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size
embedding_table = torch.cat(self.embedding_table_chucks, dim=0)
embedded_sparse = embedding_table[idx_lookup]
- embedded_sparse = torch.reshape(embedded_sparse,
- [batch_size, -1, self.embed_dim])
+ embedded_sparse = torch.reshape(
+ embedded_sparse, [batch_size, -1, self.embed_dim]
+ )
if self.embed_ln:
embedded_sparse = self.embed_ln(embedded_sparse)
# Dot product interactions.
concatenated_dense = self.dot_interact(
- dense_features=embedded_dense, sparse_features=embedded_sparse)
+ dense_features=embedded_dense, sparse_features=embedded_sparse
+ )
# Final MLP.
logits = self.top_mlp(concatenated_dense, dropout_rate)
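
On the PyTorch side, the top MLP above is invoked with a call-time dropout rate through the imported `CustomDropout`/`SequentialWithDropout` helpers. A rough sketch of how such a container could forward a per-call rate; the class names below are hypothetical stand-ins for illustration only, not the actual `algoperf.pytorch_utils` implementations:

    import torch
    import torch.nn.functional as F
    from torch import nn

    class CallTimeDropout(nn.Module):
      """Dropout whose rate is chosen per forward call (hypothetical)."""

      def forward(self, x, p):
        return F.dropout(x, p=p, training=self.training)

    class SequentialWithRate(nn.Sequential):
      """Sequential that passes `p` only to CallTimeDropout children."""

      def forward(self, x, p=0.0):
        for module in self:
          x = module(x, p) if isinstance(module, CallTimeDropout) else module(x)
        return x

    mlp = SequentialWithRate(
        nn.Linear(4, 8), nn.ReLU(), CallTimeDropout(), nn.Linear(8, 1)
    )
    out = mlp(torch.randn(2, 4), p=0.1)
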
diff --git a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
index 69f24c69d..74f91de43 100644
--- a/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
+++ b/algoperf/workloads/criteo1tb/criteo1tb_pytorch/workload.py
@@ -7,24 +7,22 @@
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
-from algoperf import param_utils
-from algoperf import spec
+from algoperf import param_utils, spec
from algoperf.pytorch_utils import pytorch_setup
from algoperf.workloads.criteo1tb.criteo1tb_pytorch import models
-from algoperf.workloads.criteo1tb.workload import \
- BaseCriteo1TbDlrmSmallWorkload
+from algoperf.workloads.criteo1tb.workload import BaseCriteo1TbDlrmSmallWorkload
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload):
-
@property
def eval_batch_size(self) -> int:
return 8_192
def _per_example_sigmoid_binary_cross_entropy(
- self, logits: spec.Tensor, targets: spec.Tensor) -> spec.Tensor:
+ self, logits: spec.Tensor, targets: spec.Tensor
+ ) -> spec.Tensor:
ls = torch.nn.LogSigmoid()
log_p = ls(logits)
log_not_p = ls(-logits)
@@ -35,11 +33,12 @@ def _per_example_sigmoid_binary_cross_entropy(
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense (not one-hot) labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense (not one-hot) labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -51,7 +50,8 @@ def loss_fn(
label_batch = torch.reshape(label_batch, (batch_size,))
logits_batch = torch.reshape(logits_batch, (batch_size,))
per_example_losses = self._per_example_sigmoid_binary_cross_entropy(
- logits=logits_batch, targets=label_batch)
+ logits=logits_batch, targets=label_batch
+ )
if mask_batch is not None:
mask_batch = torch.reshape(mask_batch, (batch_size,))
per_example_losses *= mask_batch
@@ -60,9 +60,9 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
+ 'per_example': per_example_losses,
}
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
@@ -74,13 +74,14 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
else:
model_class = models.DlrmSmall
model = model_class(
- vocab_size=self.vocab_size,
- num_dense_features=self.num_dense_features,
- mlp_bottom_dims=self.mlp_bottom_dims,
- mlp_top_dims=self.mlp_top_dims,
- embed_dim=self.embed_dim,
- use_layer_norm=self.use_layer_norm,
- embedding_init_multiplier=self.embedding_init_multiplier)
+ vocab_size=self.vocab_size,
+ num_dense_features=self.num_dense_features,
+ mlp_bottom_dims=self.mlp_bottom_dims,
+ mlp_top_dims=self.mlp_top_dims,
+ embed_dim=self.embed_dim,
+ use_layer_norm=self.use_layer_norm,
+ embedding_init_multiplier=self.embedding_init_multiplier,
+ )
self._param_shapes = param_utils.pytorch_param_shapes(model)
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
model.to(DEVICE)
@@ -95,14 +96,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['top_mlp.4.weight', 'top_mlp.4.bias']
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
@@ -118,8 +119,8 @@ def model_fn(
model.train()
contexts = {
- spec.ForwardPassMode.EVAL: torch.no_grad,
- spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
+ spec.ForwardPassMode.EVAL: torch.no_grad,
+ spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
@@ -128,14 +129,15 @@ def model_fn(
return logits_batch, None
def _build_input_queue(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
not_train = split != 'train'
per_device_batch_size = int(global_batch_size / N_GPUS)
@@ -143,35 +145,42 @@ def _build_input_queue(
# avoid creating too many threads.
if RANK == 0:
np_iter = super()._build_input_queue(
- data_rng=data_rng,
- split=split,
- data_dir=data_dir,
- global_batch_size=global_batch_size,
- num_batches=num_batches,
- repeat_final_dataset=repeat_final_dataset)
+ data_rng=data_rng,
+ split=split,
+ data_dir=data_dir,
+ global_batch_size=global_batch_size,
+ num_batches=num_batches,
+ repeat_final_dataset=repeat_final_dataset,
+ )
weights = None
while True:
if RANK == 0:
batch = next(np_iter) # pylint: disable=stop-iteration-return
inputs = torch.as_tensor(
- batch['inputs'], dtype=torch.float32, device=DEVICE)
+ batch['inputs'], dtype=torch.float32, device=DEVICE
+ )
targets = torch.as_tensor(
- batch['targets'], dtype=torch.float32, device=DEVICE)
+ batch['targets'], dtype=torch.float32, device=DEVICE
+ )
if not_train:
weights = batch.get('weights')
if weights is None:
- weights = torch.ones((N_GPUS, per_device_batch_size, 1),
- dtype=torch.float32,
- device=DEVICE)
+ weights = torch.ones(
+ (N_GPUS, per_device_batch_size, 1),
+ dtype=torch.float32,
+ device=DEVICE,
+ )
else:
weights = torch.as_tensor(
- weights, dtype=torch.float32, device=DEVICE)
+ weights, dtype=torch.float32, device=DEVICE
+ )
# Send batch to other devices when using DDP.
if USE_PYTORCH_DDP:
if not_train:
# During eval, the batch size of the remainder might be different.
per_device_batch_size = torch.tensor(
- len(targets[0]), dtype=torch.int32, device=DEVICE)
+ len(targets[0]), dtype=torch.int32, device=DEVICE
+ )
dist.broadcast(per_device_batch_size, src=0)
dist.broadcast(weights, src=0)
weights = weights[0]
@@ -187,52 +196,57 @@ def _build_input_queue(
else:
if not_train:
# During eval, the batch size of the remainder might be different.
- per_device_batch_size = torch.empty((1,),
- dtype=torch.int32,
- device=DEVICE)
+ per_device_batch_size = torch.empty(
+ (1,), dtype=torch.int32, device=DEVICE
+ )
dist.broadcast(per_device_batch_size, src=0)
- weights = torch.empty((N_GPUS, per_device_batch_size, 1),
- dtype=torch.float32,
- device=DEVICE)
+ weights = torch.empty(
+ (N_GPUS, per_device_batch_size, 1),
+ dtype=torch.float32,
+ device=DEVICE,
+ )
dist.broadcast(weights, src=0)
weights = weights[RANK]
- inputs = torch.empty((N_GPUS, per_device_batch_size, 39),
- dtype=torch.float32,
- device=DEVICE)
+ inputs = torch.empty(
+ (N_GPUS, per_device_batch_size, 39),
+ dtype=torch.float32,
+ device=DEVICE,
+ )
dist.broadcast(inputs, src=0)
inputs = inputs[RANK]
- targets = torch.empty((N_GPUS, per_device_batch_size, 1),
- dtype=torch.float32,
- device=DEVICE)
+ targets = torch.empty(
+ (N_GPUS, per_device_batch_size, 1), dtype=torch.float32, device=DEVICE
+ )
dist.broadcast(targets, src=0)
targets = targets[RANK]
if weights is None:
weights = torch.ones(per_device_batch_size, device=DEVICE)
batch = {
- 'inputs': inputs,
- 'targets': targets,
- 'weights': weights,
+ 'inputs': inputs,
+ 'targets': targets,
+ 'weights': weights,
}
yield batch
- def _eval_batch(self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor]) -> spec.Tensor:
+ def _eval_batch(
+ self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]
+ ) -> spec.Tensor:
logits, _ = self.model_fn(
- params,
- batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=None,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=None,
+ update_batch_norm=False,
+ )
weights = batch.get('weights')
if weights is None:
weights = torch.ones(len(logits), device=DEVICE)
summed_loss = self.loss_fn(
- label_batch=batch['targets'], logits_batch=logits,
- mask_batch=weights)['summed']
+ label_batch=batch['targets'], logits_batch=logits, mask_batch=weights
+ )['summed']
return summed_loss.to(dtype=torch.float64)
@@ -241,7 +255,6 @@ class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
class Criteo1TbDlrmSmallLayerNormWorkload(Criteo1TbDlrmSmallWorkload):
-
@property
def use_layer_norm(self) -> bool:
"""Whether or not to use LayerNorm in the model."""
@@ -275,7 +288,6 @@ def test_target_value(self) -> float:
class Criteo1TbDlrmSmallEmbedInitWorkload(Criteo1TbDlrmSmallWorkload):
-
@property
def validation_target_value(self) -> float:
return 0.129657
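
The loss hunks above keep two pieces separate: a per-example sigmoid binary cross-entropy built from LogSigmoid, and a masked sum alongside a count of valid examples. A minimal sketch of that computation with plain tensors (variable names are illustrative and the mask-sum bookkeeping is a simplification, not the workload API):

import torch

def per_example_sigmoid_bce(logits, targets):
  # BCE with logits: -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))],
  # using log(sigmoid(x)) = logsigmoid(x) and log(1 - sigmoid(x)) = logsigmoid(-x).
  ls = torch.nn.LogSigmoid()
  return -1.0 * (targets * ls(logits) + (1 - targets) * ls(-logits))

logits = torch.tensor([2.0, -1.0, 0.5])
targets = torch.tensor([1.0, 0.0, 1.0])
mask = torch.tensor([1.0, 1.0, 0.0])  # third example is padding

per_example = per_example_sigmoid_bce(logits, targets) * mask
summed = per_example.sum()
n_valid = mask.sum()
print(summed, n_valid, summed / n_valid)  # summed loss, valid count, mean loss
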
diff --git a/algoperf/workloads/criteo1tb/input_pipeline.py b/algoperf/workloads/criteo1tb/input_pipeline.py
index 7e254336a..bce8b11c4 100644
--- a/algoperf/workloads/criteo1tb/input_pipeline.py
+++ b/algoperf/workloads/criteo1tb/input_pipeline.py
@@ -19,32 +19,32 @@
# Raw vocab sizes from
# https://cloud.google.com/tpu/docs/tutorials/dlrm-dcn-2.x#run-model.
_VOCAB_SIZES = [
- 39884406,
- 39043,
- 17289,
- 7420,
- 20263,
- 3,
- 7120,
- 1543,
- 63,
- 38532951,
- 2953546,
- 403346,
- 10,
- 2208,
- 11938,
- 155,
- 4,
- 976,
- 14,
- 39979771,
- 25641295,
- 39664984,
- 585935,
- 12972,
- 108,
- 36,
+ 39884406,
+ 39043,
+ 17289,
+ 7420,
+ 20263,
+ 3,
+ 7120,
+ 1543,
+ 63,
+ 38532951,
+ 2953546,
+ 403346,
+ 10,
+ 2208,
+ 11938,
+ 155,
+ 4,
+ 976,
+ 14,
+ 39979771,
+ 25641295,
+ 39664984,
+ 585935,
+ 12972,
+ 108,
+ 36,
]
@@ -60,7 +60,8 @@ def _parse_example_fn(num_dense_features, example):
categorical_defaults = [['00000000'] for _ in range(len(_VOCAB_SIZES))]
record_defaults = label_defaults + int_defaults + categorical_defaults
fields = tf.io.decode_csv(
- example, record_defaults, field_delim='\t', na_value='-1')
+ example, record_defaults, field_delim='\t', na_value='-1'
+ )
num_labels = 1
features = {}
@@ -78,20 +79,24 @@ def _parse_example_fn(num_dense_features, example):
# We append the column index to the string to make the same id in different
# columns unique.
cat_features.append(
- tf.strings.to_hash_bucket_fast(field + str(idx), _VOCAB_SIZES[idx]))
+ tf.strings.to_hash_bucket_fast(field + str(idx), _VOCAB_SIZES[idx])
+ )
cat_features = tf.cast(
- tf.stack(cat_features, axis=1), dtype=int_features.dtype)
+ tf.stack(cat_features, axis=1), dtype=int_features.dtype
+ )
features['inputs'] = tf.concat([int_features, cat_features], axis=1)
return features
-def get_criteo1tb_dataset(split: str,
- shuffle_rng,
- data_dir: str,
- num_dense_features: int,
- global_batch_size: int,
- num_batches: Optional[int] = None,
- repeat_final_dataset: bool = False):
+def get_criteo1tb_dataset(
+ split: str,
+ shuffle_rng,
+ data_dir: str,
+ num_dense_features: int,
+ global_batch_size: int,
+ num_batches: Optional[int] = None,
+ repeat_final_dataset: bool = False,
+):
"""Get the Criteo 1TB dataset for a given split."""
num_test_files = _NUM_DAY_23_FILES // 2 + 1
if split in ['train', 'eval_train']:
@@ -99,19 +104,20 @@ def get_criteo1tb_dataset(split: str,
elif split == 'validation':
# Assumes files are of the format day_23_04.
file_paths = [
- os.path.join(data_dir, f'day_23_{str(s).zfill(2)}')
- for s in range(num_test_files, _NUM_DAY_23_FILES)
+ os.path.join(data_dir, f'day_23_{str(s).zfill(2)}')
+ for s in range(num_test_files, _NUM_DAY_23_FILES)
]
else:
file_paths = [
- os.path.join(data_dir, f'day_23_{str(s).zfill(2)}')
- for s in range(0, num_test_files)
+ os.path.join(data_dir, f'day_23_{str(s).zfill(2)}')
+ for s in range(0, num_test_files)
]
is_training = split == 'train'
shuffle = is_training or split == 'eval_train'
ds = tf.data.Dataset.list_files(
- file_paths, shuffle=shuffle, seed=shuffle_rng[0])
+ file_paths, shuffle=shuffle, seed=shuffle_rng[0]
+ )
if shuffle:
ds = ds.shuffle(buffer_size=1024)
@@ -132,9 +138,10 @@ def get_criteo1tb_dataset(split: str,
ds = ds.repeat()
ds = map(
- functools.partial(
- data_utils.shard_and_maybe_pad_np,
- global_batch_size=global_batch_size),
- ds)
+ functools.partial(
+ data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size
+ ),
+ ds,
+ )
return ds
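
The `_parse_example_fn` hunk above hashes each categorical column into its own vocabulary and appends the column index to the raw string, so identical ids in different columns land in different buckets. A small standalone illustration (toy vocab sizes, not the real `_VOCAB_SIZES`):

import tensorflow as tf

vocab_sizes = [10, 10]
# Two categorical columns that happen to share the raw id 'cafe'.
columns = [tf.constant(['cafe', '00ff']), tf.constant(['cafe', 'abcd'])]

cat_features = []
for idx, field in enumerate(columns):
  # Appending str(idx) disambiguates the same string across columns.
  cat_features.append(
      tf.strings.to_hash_bucket_fast(field + str(idx), vocab_sizes[idx])
  )
# 'cafe' is hashed as 'cafe0' vs 'cafe1', so the two columns (almost surely)
# end up in different buckets.
print(tf.stack(cat_features, axis=1))
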
diff --git a/algoperf/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py
index 617b2e987..9fb819203 100644
--- a/algoperf/workloads/criteo1tb/workload.py
+++ b/algoperf/workloads/criteo1tb/workload.py
@@ -29,8 +29,9 @@ def target_metric_name(self) -> str:
"""The name of the target metric (useful for scoring/processing code)."""
return 'loss'
- def has_reached_validation_target(self, eval_result: Dict[str,
- float]) -> bool:
+ def has_reached_validation_target(
+ self, eval_result: Dict[str, float]
+ ) -> bool:
return eval_result['validation/loss'] < self.validation_target_value
@property
@@ -71,8 +72,9 @@ def num_eval_train_examples(self) -> int:
# Round up from num_validation_examples (which is the default for
# num_eval_train_examples) to the next multiple of eval_batch_size, so that
# we don't have to extract the correctly sized subset of the training data.
- rounded_up_multiple = math.ceil(self.num_validation_examples /
- self.eval_batch_size)
+ rounded_up_multiple = math.ceil(
+ self.num_validation_examples / self.eval_batch_size
+ )
return rounded_up_multiple * self.eval_batch_size
@property
@@ -100,23 +102,25 @@ def eval_period_time_sec(self) -> int:
return 2 * 60 # 2 mins.
def _build_input_queue(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
del cache
ds = input_pipeline.get_criteo1tb_dataset(
- split=split,
- shuffle_rng=data_rng,
- data_dir=data_dir,
- num_dense_features=self.num_dense_features,
- global_batch_size=global_batch_size,
- num_batches=num_batches,
- repeat_final_dataset=repeat_final_dataset)
+ split=split,
+ shuffle_rng=data_rng,
+ data_dir=data_dir,
+ num_dense_features=self.num_dense_features,
+ global_batch_size=global_batch_size,
+ num_batches=num_batches,
+ repeat_final_dataset=repeat_final_dataset,
+ )
for batch in iter(ds):
yield batch
@@ -126,15 +130,17 @@ def step_hint(self) -> int:
"""Approx. steps the baseline can do in the allowed runtime budget."""
return 10_666
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del model_state
del global_step
@@ -142,12 +148,13 @@ def _eval_model_on_split(self,
if split not in self._eval_iters:
# These iterators will repeat indefinitely.
self._eval_iters[split] = self._build_input_queue(
- data_rng=rng,
- split=split,
- data_dir=data_dir,
- global_batch_size=global_batch_size,
- num_batches=num_batches,
- repeat_final_dataset=True)
+ data_rng=rng,
+ split=split,
+ data_dir=data_dir,
+ global_batch_size=global_batch_size,
+ num_batches=num_batches,
+ repeat_final_dataset=True,
+ )
loss = 0.0
for _ in range(num_batches):
eval_batch = next(self._eval_iters[split])
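
The `num_eval_train_examples` hunk above rounds the validation-set size up to the next multiple of the eval batch size. With an eval batch size of 8,192 (the value used by this workload) and a made-up validation count, the arithmetic looks like this:

import math

eval_batch_size = 8_192            # matches the workload's eval_batch_size
num_validation_examples = 100_000  # made-up figure for illustration

rounded_up_multiple = math.ceil(num_validation_examples / eval_batch_size)  # 13
num_eval_train_examples = rounded_up_multiple * eval_batch_size             # 106_496
print(rounded_up_multiple, num_eval_train_examples)
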
diff --git a/algoperf/workloads/fastmri/fastmri_jax/models.py b/algoperf/workloads/fastmri/fastmri_jax/models.py
index a5fe060b9..b80c370ea 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/models.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/models.py
@@ -12,6 +12,7 @@
Data:
github.com/facebookresearch/fastMRI/tree/main/fastmri/data
"""
+
import functools
import flax.linen as nn
@@ -31,7 +32,7 @@ def _instance_norm2d(x, axes, epsilon=1e-5):
mean2 = jnp.mean(jnp.square(x), axes)
  # mean2 - jnp.square(mean) is not guaranteed to be non-negative due
# to floating point round-off errors.
- var = jnp.maximum(0., mean2 - jnp.square(mean))
+ var = jnp.maximum(0.0, mean2 - jnp.square(mean))
stats_shape = list(x.shape)
for axis in axes:
stats_shape[axis] = 1
@@ -46,16 +47,17 @@ def _instance_norm2d(x, axes, epsilon=1e-5):
class UNet(nn.Module):
"""Jax / Flax implementation of a U-Net model.
- O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks
- for biomedical image segmentation. In International Conference on Medical
- image computing and computer-assisted intervention, pages 234–241.
- Springer, 2015.
+ O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks
+ for biomedical image segmentation. In International Conference on Medical
+ image computing and computer-assisted intervention, pages 234–241.
+ Springer, 2015.
- out_channels: Number of channels in the output to the U-Net model.
- channels: Number of output channels of the first convolution layer.
- num_pool_layers: Number of down-sampling and up-sampling layers.
- dropout_rate: Dropout probability.
+ out_channels: Number of channels in the output to the U-Net model.
+ channels: Number of output channels of the first convolution layer.
+ num_pool_layers: Number of down-sampling and up-sampling layers.
+ dropout_rate: Dropout probability.
"""
+
num_channels: int = 32
num_pool_layers: int = 4
out_channels = 1
@@ -67,14 +69,16 @@ class UNet(nn.Module):
def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE):
# pylint: disable=invalid-name
_ConvBlock = functools.partial(
- ConvBlock,
- dropout_rate=dropout_rate,
- use_tanh=self.use_tanh,
- use_layer_norm=self.use_layer_norm)
+ ConvBlock,
+ dropout_rate=dropout_rate,
+ use_tanh=self.use_tanh,
+ use_layer_norm=self.use_layer_norm,
+ )
_TransposeConvBlock = functools.partial(
- TransposeConvBlock,
- use_tanh=self.use_tanh,
- use_layer_norm=self.use_layer_norm)
+ TransposeConvBlock,
+ use_tanh=self.use_tanh,
+ use_layer_norm=self.use_layer_norm,
+ )
down_sample_layers = [_ConvBlock(self.num_channels)]
@@ -125,9 +129,9 @@ def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE):
output = jnp.concatenate((output, downsample_layer), axis=-1)
output = conv(output, train)
- output = nn.Conv(
- self.out_channels, kernel_size=(1, 1), strides=(1, 1))(
- output)
+ output = nn.Conv(self.out_channels, kernel_size=(1, 1), strides=(1, 1))(
+ output
+ )
return output.squeeze(-1)
@@ -136,6 +140,7 @@ class ConvBlock(nn.Module):
out_channels: Number of channels in the output.
dropout_rate: Dropout probability.
"""
+
out_channels: int
use_tanh: bool
use_layer_norm: bool
@@ -152,11 +157,11 @@ def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE):
jnp.array: Output tensor of shape `(N, H, W, out_channels)`.
"""
x = nn.Conv(
- features=self.out_channels,
- kernel_size=(3, 3),
- strides=(1, 1),
- use_bias=False)(
- x)
+ features=self.out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ use_bias=False,
+ )(x)
if self.use_layer_norm:
x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x)
else:
@@ -171,23 +176,23 @@ def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE):
x = activation_fn(x)
# Ref code uses dropout2d which applies the same mask for the entire channel
# Replicated by using broadcast dims to have the same filter on HW
- x = Dropout(
- dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
- x, rate=dropout_rate)
+ x = Dropout(dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
+ x, rate=dropout_rate
+ )
x = nn.Conv(
- features=self.out_channels,
- kernel_size=(3, 3),
- strides=(1, 1),
- use_bias=False)(
- x)
+ features=self.out_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ use_bias=False,
+ )(x)
if self.use_layer_norm:
x = nn.LayerNorm(reduction_axes=(1, 2, 3))(x)
else:
x = _instance_norm2d(x, (1, 2))
x = activation_fn(x)
- x = Dropout(
- dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
- x, rate=dropout_rate)
+ x = Dropout(dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
+ x, rate=dropout_rate
+ )
return x
@@ -195,6 +200,7 @@ class TransposeConvBlock(nn.Module):
"""A Transpose Convolutional Block.
out_channels: Number of channels in the output.
"""
+
out_channels: int
use_tanh: bool
use_layer_norm: bool
@@ -208,8 +214,8 @@ def __call__(self, x):
jnp.array: Output tensor of shape `(N, H*2, W*2, out_channels)`.
"""
x = nn.ConvTranspose(
- self.out_channels, kernel_size=(2, 2), strides=(2, 2), use_bias=False)(
- x)
+ self.out_channels, kernel_size=(2, 2), strides=(2, 2), use_bias=False
+ )(x)
x = _instance_norm2d(x, (1, 2))
if self.use_tanh:
activation_fn = nn.tanh
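
The ConvBlock hunks above replicate the reference dropout2d behaviour by passing `broadcast_dims=(1, 2)`, so one mask is drawn per (example, channel) and broadcast over H and W. A minimal sketch with the stock `flax.linen.Dropout` (the forked Dropout used in this patch additionally accepts `rate` at call time, but the broadcasting behaviour is the same):

import jax
import jax.numpy as jnp
import flax.linen as nn

drop = nn.Dropout(rate=0.5, broadcast_dims=(1, 2), deterministic=False)
x = jnp.ones((2, 4, 4, 3))  # NHWC
y = drop.apply({}, x, rngs={'dropout': jax.random.PRNGKey(0)})
# Each (example, channel) spatial slice is either entirely zero or entirely
# kept and rescaled by 1 / (1 - rate).
print(y[0, :, :, 0])
print(y[0, :, :, 1])
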
diff --git a/algoperf/workloads/fastmri/fastmri_jax/ssim.py b/algoperf/workloads/fastmri/fastmri_jax/ssim.py
index e15b93616..ca2ee1b60 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/ssim.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/ssim.py
@@ -49,12 +49,9 @@ def ssim(logits, targets, mean=None, std=None, volume_max=None):
return ssims
-def structural_similarity(im1,
- im2,
- data_range=1.0,
- win_size=7,
- k1=0.01,
- k2=0.03):
+def structural_similarity(
+ im1, im2, data_range=1.0, win_size=7, k1=0.01, k2=0.03
+):
"""Compute the mean structural similarity index between two images.
NOTE(dsuo): modified from skimage.metrics.structural_similarity.
@@ -85,7 +82,7 @@ def structural_similarity(im1,
"""
filter_func = functools.partial(_uniform_filter, size=win_size)
- num_points = win_size**len(im1.shape)
+ num_points = win_size ** len(im1.shape)
# filter has already normalized by num_points
cov_norm = num_points / (num_points - 1) # sample covariance
@@ -102,8 +99,8 @@ def structural_similarity(im1,
vy = cov_norm * (uyy - uy * uy)
vxy = cov_norm * (uxy - ux * uy)
- c1 = (k1 * data_range)**2
- c2 = (k2 * data_range)**2
+ c1 = (k1 * data_range) ** 2
+ c2 = (k2 * data_range) ** 2
a1 = 2 * ux * uy + c1
a2 = 2 * vxy + c2
@@ -121,12 +118,15 @@ def structural_similarity(im1,
def _uniform_filter(im, size=7):
-
def conv(im):
- return jnp.convolve(
+ return (
+ jnp.convolve(
jnp.pad(im, pad_width=size // 2, mode='symmetric'),
jnp.ones(size),
- mode='valid') / size
+ mode='valid',
+ )
+ / size
+ )
im = jax.vmap(conv, (0,))(im)
im = jax.vmap(conv, (1,))(im)
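
For reference, the constants in `structural_similarity` above work out as follows under its defaults (`data_range=1.0`, `win_size=7`, `k1=0.01`, `k2=0.03`) on a 2-D image; c1 and c2 are the small constants that keep the ratios stable when the local means and variances are near zero:

win_size, k1, k2, data_range = 7, 0.01, 0.03, 1.0

num_points = win_size ** 2                 # 49 pixels per 7x7 window (2-D image)
cov_norm = num_points / (num_points - 1)   # 49/48: unbiased sample covariance
c1 = (k1 * data_range) ** 2                # 1e-4
c2 = (k2 * data_range) ** 2                # 9e-4
print(num_points, cov_norm, c1, c2)
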
diff --git a/algoperf/workloads/fastmri/fastmri_jax/workload.py b/algoperf/workloads/fastmri/fastmri_jax/workload.py
index bd0aa1d0b..08bb25014 100644
--- a/algoperf/workloads/fastmri/fastmri_jax/workload.py
+++ b/algoperf/workloads/fastmri/fastmri_jax/workload.py
@@ -4,32 +4,29 @@
import math
from typing import Dict, Optional, Tuple
-from flax import jax_utils
import jax
import jax.numpy as jnp
+from flax import jax_utils
-from algoperf import param_utils
-from algoperf import spec
import algoperf.random_utils as prng
-from algoperf.workloads.fastmri.fastmri_jax.models import DROPOUT_RATE
-from algoperf.workloads.fastmri.fastmri_jax.models import UNet
+from algoperf import param_utils, spec
+from algoperf.workloads.fastmri.fastmri_jax.models import DROPOUT_RATE, UNet
from algoperf.workloads.fastmri.fastmri_jax.ssim import ssim
from algoperf.workloads.fastmri.workload import BaseFastMRIWorkload
class FastMRIWorkload(BaseFastMRIWorkload):
-
def init_model_fn(
- self,
- rng: spec.RandomState,
+ self,
+ rng: spec.RandomState,
) -> spec.ModelInitState:
"""aux_dropout_rate is unused."""
fake_batch = jnp.zeros((13, 320, 320))
self._model = UNet(
- num_pool_layers=self.num_pool_layers,
- num_channels=self.num_channels,
- use_tanh=self.use_tanh,
- use_layer_norm=self.use_layer_norm,
+ num_pool_layers=self.num_pool_layers,
+ num_channels=self.num_channels,
+ use_tanh=self.use_tanh,
+ use_layer_norm=self.use_layer_norm,
)
params_rng, _ = jax.random.split(rng)
@@ -45,34 +42,37 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Conv_0'
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- dropout_rate: float = DROPOUT_RATE
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = DROPOUT_RATE,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
train = mode == spec.ForwardPassMode.TRAIN
- logits = self._model.apply({'params': params},
- augmented_and_preprocessed_input_batch['inputs'],
- rngs={'dropout': rng},
- train=train,
- dropout_rate=dropout_rate)
+ logits = self._model.apply(
+ {'params': params},
+ augmented_and_preprocessed_input_batch['inputs'],
+ rngs={'dropout': rng},
+ train=train,
+ dropout_rate=dropout_rate,
+ )
return logits, None
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -81,8 +81,9 @@ def loss_fn(
"""
del label_smoothing
per_example_losses = jnp.mean(
- jnp.abs(logits_batch - label_batch),
- axis=tuple(range(1, logits_batch.ndim)))
+ jnp.abs(logits_batch - label_batch),
+ axis=tuple(range(1, logits_batch.ndim)),
+ )
# mask_batch is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -91,56 +92,63 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': n_valid_examples,
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': n_valid_examples,
+ 'per_example': per_example_losses,
}
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, 0, 0, 0),
- static_broadcasted_argnums=(0,))
- def _eval_model(self,
- params: spec.Tensor,
- batch: Dict[str, spec.Tensor],
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, 0, 0, 0),
+ static_broadcasted_argnums=(0,),
+ )
+ def _eval_model(
+ self,
+ params: spec.Tensor,
+ batch: Dict[str, spec.Tensor],
+ rng: spec.RandomState,
+ ) -> Dict[str, spec.Tensor]:
"""Return the SSIM and loss as a dict."""
logits, _ = self.model_fn(
- params,
- batch,
- model_state=None,
- mode=spec.ForwardPassMode.EVAL,
- rng=rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state=None,
+ mode=spec.ForwardPassMode.EVAL,
+ rng=rng,
+ update_batch_norm=False,
+ )
targets = batch['targets']
weights = batch.get('weights')
if weights is None:
weights = jnp.ones(len(logits))
ssim_vals = ssim(
- logits,
- targets,
- mean=batch['mean'],
- std=batch['std'],
- volume_max=batch['volume_max'])
+ logits,
+ targets,
+ mean=batch['mean'],
+ std=batch['std'],
+ volume_max=batch['volume_max'],
+ )
ssim_sum = jnp.sum(ssim_vals * weights)
summed_loss = self.loss_fn(targets, logits, weights)['summed']
metrics = {
- 'ssim': ssim_sum,
- 'loss': summed_loss,
+ 'ssim': ssim_sum,
+ 'loss': summed_loss,
}
metrics = jax.lax.psum(metrics, axis_name='batch')
return metrics
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del model_state
del global_step
@@ -149,27 +157,27 @@ def _eval_model_on_split(self,
if split not in self._eval_iters:
# These iterators repeat indefinitely.
self._eval_iters[split] = self._build_input_queue(
- data_rng,
- split,
- data_dir,
- global_batch_size=global_batch_size,
- repeat_final_dataset=True,
- num_batches=num_batches)
-
- total_metrics = {'ssim': 0., 'loss': 0.}
+ data_rng,
+ split,
+ data_dir,
+ global_batch_size=global_batch_size,
+ repeat_final_dataset=True,
+ num_batches=num_batches,
+ )
+
+ total_metrics = {'ssim': 0.0, 'loss': 0.0}
eval_rngs = prng.split(model_rng, jax.local_device_count())
for _ in range(num_batches):
batch = next(self._eval_iters[split])
# We already sum these metrics across devices inside _eval_model.
synced_metrics = self._eval_model(params, batch, eval_rngs)
total_metrics = {
- k: v + synced_metrics[k][0] for k, v in total_metrics.items()
+ k: v + synced_metrics[k][0] for k, v in total_metrics.items()
}
return {k: float(v.item() / num_examples) for k, v in total_metrics.items()}
class FastMRIModelSizeWorkload(FastMRIWorkload):
-
@property
  def num_pool_layers(self) -> int:
    """Number of down-sampling and up-sampling layers in the model."""
@@ -190,7 +198,6 @@ def test_target_value(self) -> float:
class FastMRITanhWorkload(FastMRIWorkload):
-
@property
def use_tanh(self) -> bool:
"""Whether or not to use tanh activations in the model."""
@@ -206,7 +213,6 @@ def test_target_value(self) -> float:
class FastMRILayerNormWorkload(FastMRIWorkload):
-
@property
def use_layer_norm(self) -> bool:
"""Whether or not to use tanh activations in the model."""
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/models.py b/algoperf/workloads/fastmri/fastmri_pytorch/models.py
index 8441f06c2..16cf8bd54 100644
--- a/algoperf/workloads/fastmri/fastmri_pytorch/models.py
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/models.py
@@ -7,31 +7,31 @@
from functools import partial
import torch
-from torch import nn
-from torch import Tensor
+from torch import Tensor, nn
from torch.nn import functional as F
from algoperf import init_utils
-from algoperf.pytorch_utils import CustomDropout2d
-from algoperf.pytorch_utils import SequentialWithDropout
+from algoperf.pytorch_utils import CustomDropout2d, SequentialWithDropout
DROPOUT_RATE = 0.0
class UNet(nn.Module):
r"""U-Net model from
- `"U-net: Convolutional networks
- for biomedical image segmentation"
- `_.
- """
-
- def __init__(self,
- in_chans: int = 1,
- out_chans: int = 1,
- num_channels: int = 32,
- num_pool_layers: int = 4,
- use_tanh: bool = False,
- use_layer_norm: bool = False) -> None:
+ `"U-net: Convolutional networks
+ for biomedical image segmentation"
+ `_.
+ """
+
+ def __init__(
+ self,
+ in_chans: int = 1,
+ out_chans: int = 1,
+ num_channels: int = 32,
+ num_pool_layers: int = 4,
+ use_tanh: bool = False,
+ use_layer_norm: bool = False,
+ ) -> None:
super().__init__()
self.in_chans = in_chans
@@ -40,11 +40,13 @@ def __init__(self,
self.num_pool_layers = num_pool_layers
self.down_sample_layers = nn.ModuleList(
- [ConvBlock(in_chans, num_channels, use_tanh, use_layer_norm)])
+ [ConvBlock(in_chans, num_channels, use_tanh, use_layer_norm)]
+ )
ch = num_channels
for _ in range(num_pool_layers - 1):
self.down_sample_layers.append(
- ConvBlock(ch, ch * 2, use_tanh, use_layer_norm))
+ ConvBlock(ch, ch * 2, use_tanh, use_layer_norm)
+ )
ch *= 2
self.conv = ConvBlock(ch, ch * 2, use_tanh, use_layer_norm)
@@ -53,24 +55,26 @@ def __init__(self,
for _ in range(num_pool_layers - 1):
self.up_transpose_conv.append(
- TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
+ TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)
+ )
self.up_conv.append(ConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
ch //= 2
self.up_transpose_conv.append(
- TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm))
+ TransposeConvBlock(ch * 2, ch, use_tanh, use_layer_norm)
+ )
self.up_conv.append(
- SequentialWithDropout(
- ConvBlock(ch * 2, ch, use_tanh, use_layer_norm),
- nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1),
- ))
+ SequentialWithDropout(
+ ConvBlock(ch * 2, ch, use_tanh, use_layer_norm),
+ nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1),
+ )
+ )
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
init_utils.pytorch_default_init(m)
def forward(self, x: Tensor, dropout_rate: float = DROPOUT_RATE) -> Tensor:
-
stack = []
output = x
@@ -95,7 +99,7 @@ def forward(self, x: Tensor, dropout_rate: float = DROPOUT_RATE) -> Tensor:
if output.shape[-2] != downsample_layer.shape[-2]:
padding[3] = 1 # padding bottom
if torch.sum(torch.tensor(padding)) != 0:
- output = F.pad(output, padding, "reflect")
+ output = F.pad(output, padding, 'reflect')
output = torch.cat([output, downsample_layer], dim=1)
output = conv(output, dropout_rate)
@@ -107,11 +111,9 @@ class ConvBlock(nn.Module):
# A Convolutional Block that consists of two convolution layers each
# followed by instance normalization, LeakyReLU activation and dropout_rate.
- def __init__(self,
- in_chans: int,
- out_chans: int,
- use_tanh: bool,
- use_layer_norm: bool) -> None:
+ def __init__(
+ self, in_chans: int, out_chans: int, use_tanh: bool, use_layer_norm: bool
+ ) -> None:
super().__init__()
self._supports_custom_dropout = True
@@ -124,14 +126,14 @@ def __init__(self,
else:
activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.conv_layers = SequentialWithDropout(
- nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
- norm_layer(out_chans),
- activation_fn,
- CustomDropout2d(),
- nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
- norm_layer(out_chans),
- activation_fn,
- CustomDropout2d(),
+ nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
+ norm_layer(out_chans),
+ activation_fn,
+ CustomDropout2d(),
+ nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
+ norm_layer(out_chans),
+ activation_fn,
+ CustomDropout2d(),
)
def forward(self, x: Tensor, dropout_rate: float) -> Tensor:
@@ -143,11 +145,11 @@ class TransposeConvBlock(nn.Module):
# layers followed by instance normalization and LeakyReLU activation.
def __init__(
- self,
- in_chans: int,
- out_chans: int,
- use_tanh: bool,
- use_layer_norm: bool,
+ self,
+ in_chans: int,
+ out_chans: int,
+ use_tanh: bool,
+ use_layer_norm: bool,
):
super().__init__()
if use_tanh:
@@ -155,10 +157,11 @@ def __init__(
else:
activation_fn = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.layers = nn.Sequential(
- nn.ConvTranspose2d(
- in_chans, out_chans, kernel_size=2, stride=2, bias=False),
- nn.InstanceNorm2d(out_chans),
- activation_fn,
+ nn.ConvTranspose2d(
+ in_chans, out_chans, kernel_size=2, stride=2, bias=False
+ ),
+ nn.InstanceNorm2d(out_chans),
+ activation_fn,
)
def forward(self, x: Tensor) -> Tensor:
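
In `UNet.forward` above, the up-sampled activations can come out one pixel short of the matching skip connection, so they are reflect-padded on the right/bottom before concatenation. A toy-shape sketch of that step:

import torch
import torch.nn.functional as F

output = torch.randn(1, 8, 15, 15)  # up-sampled features, one pixel short
skip = torch.randn(1, 8, 16, 16)    # matching down-sample activations

padding = [0, 0, 0, 0]  # (left, right, top, bottom) over the last two dims
if output.shape[-1] != skip.shape[-1]:
  padding[1] = 1  # pad right
if output.shape[-2] != skip.shape[-2]:
  padding[3] = 1  # pad bottom
if sum(padding) != 0:
  output = F.pad(output, padding, 'reflect')

merged = torch.cat([output, skip], dim=1)
print(merged.shape)  # torch.Size([1, 16, 16, 16])
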
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py b/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py
index 45b61bea4..7d594b959 100644
--- a/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/ssim.py
@@ -32,9 +32,9 @@ def ssim(logits, targets, mean=None, std=None, volume_max=None):
# NOTE(dsuo): `volume_max` can be 0 if we have a padded batch, but this will
# lead to NaN values in `ssim`.
- volume_max = torch.where(volume_max == 0,
- torch.ones_like(volume_max),
- volume_max)
+ volume_max = torch.where(
+ volume_max == 0, torch.ones_like(volume_max), volume_max
+ )
if mean is None:
mean = torch.zeros(logits.shape[0], device=DEVICE)
@@ -56,12 +56,9 @@ def ssim(logits, targets, mean=None, std=None, volume_max=None):
return ssims
-def structural_similarity(im1,
- im2,
- data_range=1.0,
- win_size=7,
- k1=0.01,
- k2=0.03):
+def structural_similarity(
+ im1, im2, data_range=1.0, win_size=7, k1=0.01, k2=0.03
+):
"""Compute the mean structural similarity index between two images.
NOTE(dsuo): modified from skimage.metrics.structural_similarity.
@@ -92,7 +89,7 @@ def structural_similarity(im1,
"""
filter_func = functools.partial(_uniform_filter, size=win_size)
- num_points = win_size**len(im1.shape)
+ num_points = win_size ** len(im1.shape)
# filter has already normalized by num_points
cov_norm = num_points / (num_points - 1) # sample covariance
@@ -109,8 +106,8 @@ def structural_similarity(im1,
vy = cov_norm * (uyy - uy * uy)
vxy = cov_norm * (uxy - ux * uy)
- c1 = (k1 * data_range)**2
- c2 = (k2 * data_range)**2
+ c1 = (k1 * data_range) ** 2
+ c2 = (k2 * data_range) ** 2
a1 = 2 * ux * uy + c1
a2 = 2 * vxy + c2
diff --git a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
index 1adbb57ca..bddf6b1f3 100644
--- a/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
+++ b/algoperf/workloads/fastmri/fastmri_pytorch/workload.py
@@ -9,10 +9,8 @@
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
-from algoperf import param_utils
-from algoperf import pytorch_utils
-from algoperf import spec
import algoperf.random_utils as prng
+from algoperf import param_utils, pytorch_utils, spec
from algoperf.workloads.fastmri.fastmri_pytorch import models
from algoperf.workloads.fastmri.fastmri_pytorch.models import UNet
from algoperf.workloads.fastmri.fastmri_pytorch.ssim import ssim
@@ -22,28 +20,31 @@
class FastMRIWorkload(BaseFastMRIWorkload):
-
- def _build_input_queue(self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None):
+ def _build_input_queue(
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ):
per_device_batch_size = int(global_batch_size / N_GPUS)
# Only create and iterate over tf input pipeline in one Python process to
# avoid creating too many threads.
if RANK == 0:
data_rng = data_rng.astype('uint32')
- np_iter = super()._build_input_queue(data_rng,
- split,
- data_dir,
- global_batch_size,
- cache,
- repeat_final_dataset,
- num_batches)
+ np_iter = super()._build_input_queue(
+ data_rng,
+ split,
+ data_dir,
+ global_batch_size,
+ cache,
+ repeat_final_dataset,
+ num_batches,
+ )
while True:
if RANK == 0:
@@ -59,20 +60,23 @@ def _build_input_queue(self,
else:
aux_tensor_list.append(tensor)
batch[key] = (
- tensor[0] if USE_PYTORCH_DDP else tensor.view(
- -1, *value.shape[2:]))
+ tensor[0] if USE_PYTORCH_DDP else tensor.view(-1, *value.shape[2:])
+ )
# Send batch to other devices when using DDP.
if USE_PYTORCH_DDP:
if split != 'train':
# During eval, the batch size of the remainder might be different.
per_device_batch_size = torch.tensor(
- len(batch['inputs']), dtype=torch.int32, device=DEVICE)
+ len(batch['inputs']), dtype=torch.int32, device=DEVICE
+ )
dist.broadcast(per_device_batch_size, src=0)
weights = weights if 'weights' in batch else None
if weights is None:
- weights = torch.ones((N_GPUS, per_device_batch_size),
- dtype=torch.float64,
- device=DEVICE)
+ weights = torch.ones(
+ (N_GPUS, per_device_batch_size),
+ dtype=torch.float64,
+ device=DEVICE,
+ )
# Has no effect, but without it `batch` has no `weights` key
# for RANK == 0, but has one for all others.
batch['weights'] = weights[0]
@@ -83,20 +87,22 @@ def _build_input_queue(self,
batch = {}
if split != 'train':
# During eval, the batch size of the remainder might be different.
- per_device_batch_size = torch.empty((),
- dtype=torch.int32,
- device=DEVICE)
+ per_device_batch_size = torch.empty(
+ (), dtype=torch.int32, device=DEVICE
+ )
dist.broadcast(per_device_batch_size, src=0)
- weights = torch.empty((N_GPUS, per_device_batch_size),
- dtype=torch.float64,
- device=DEVICE)
+ weights = torch.empty(
+ (N_GPUS, per_device_batch_size), dtype=torch.float64, device=DEVICE
+ )
dist.broadcast(weights, src=0)
batch['weights'] = weights[RANK]
- tensors = torch.empty((2, N_GPUS, per_device_batch_size, 320, 320),
- device=DEVICE)
+ tensors = torch.empty(
+ (2, N_GPUS, per_device_batch_size, 320, 320), device=DEVICE
+ )
dist.broadcast(tensors, src=0)
- aux_tensors = torch.empty((3, N_GPUS, per_device_batch_size),
- device=DEVICE)
+ aux_tensors = torch.empty(
+ (3, N_GPUS, per_device_batch_size), device=DEVICE
+ )
dist.broadcast(aux_tensors, src=0)
# Note that the batch dict in the RANK == 0 process is ordered.
batch['inputs'] = tensors[0][RANK]
@@ -109,10 +115,11 @@ def _build_input_queue(self,
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
model = UNet(
- num_pool_layers=self.num_pool_layers,
- num_channels=self.num_channels,
- use_tanh=self.use_tanh,
- use_layer_norm=self.use_layer_norm)
+ num_pool_layers=self.num_pool_layers,
+ num_channels=self.num_channels,
+ use_tanh=self.use_tanh,
+ use_layer_norm=self.use_layer_norm,
+ )
self._param_shapes = param_utils.pytorch_param_shapes(model)
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
model.to(DEVICE)
@@ -127,14 +134,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['up_conv.3.1.weight', 'up_conv.3.1.bias']
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
@@ -149,25 +156,27 @@ def model_fn(
model.train()
contexts = {
- spec.ForwardPassMode.EVAL: torch.no_grad,
- spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
+ spec.ForwardPassMode.EVAL: torch.no_grad,
+ spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
logit_batch = model(
- augmented_and_preprocessed_input_batch['inputs'].unsqueeze(1),
- dropout_rate=dropout_rate).squeeze(1)
+ augmented_and_preprocessed_input_batch['inputs'].unsqueeze(1),
+ dropout_rate=dropout_rate,
+ ).squeeze(1)
return logit_batch, None
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -176,7 +185,8 @@ def loss_fn(
"""
del label_smoothing
per_example_losses = F.l1_loss(
- logits_batch, label_batch, reduction='none').mean(dim=(1, 2))
+ logits_batch, label_batch, reduction='none'
+ ).mean(dim=(1, 2))
# mask_batch is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -185,46 +195,52 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
+ 'per_example': per_example_losses,
}
- def _eval_model(self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+ def _eval_model(
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ rng: spec.RandomState,
+ ) -> Dict[str, spec.Tensor]:
"""Return the SSIM and loss as a dict."""
outputs, _ = self.model_fn(
- params,
- batch,
- None,
- spec.ForwardPassMode.EVAL,
- rng,
- update_batch_norm=False)
+ params,
+ batch,
+ None,
+ spec.ForwardPassMode.EVAL,
+ rng,
+ update_batch_norm=False,
+ )
targets = batch['targets']
weights = batch.get('weights')
if weights is None:
weights = torch.ones(len(outputs), device=DEVICE)
weights_sum = weights.sum().to(torch.int)
ssim_sum = ssim(
- outputs[:weights_sum],
- targets[:weights_sum],
- mean=batch['mean'][:weights_sum],
- std=batch['std'][:weights_sum],
- volume_max=batch['volume_max'][:weights_sum]).sum()
+ outputs[:weights_sum],
+ targets[:weights_sum],
+ mean=batch['mean'][:weights_sum],
+ std=batch['std'][:weights_sum],
+ volume_max=batch['volume_max'][:weights_sum],
+ ).sum()
summed_loss = self.loss_fn(targets, outputs, weights)['summed']
return {'ssim': ssim_sum, 'loss': summed_loss}
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del model_state
del global_step
@@ -233,22 +249,23 @@ def _eval_model_on_split(self,
if split not in self._eval_iters:
# These iterators repeat indefinitely.
self._eval_iters[split] = self._build_input_queue(
- data_rng,
- split,
- data_dir,
- global_batch_size=global_batch_size,
- repeat_final_dataset=True,
- num_batches=num_batches)
+ data_rng,
+ split,
+ data_dir,
+ global_batch_size=global_batch_size,
+ repeat_final_dataset=True,
+ num_batches=num_batches,
+ )
total_metrics = {
- 'ssim': torch.tensor(0., device=DEVICE),
- 'loss': torch.tensor(0., device=DEVICE),
+ 'ssim': torch.tensor(0.0, device=DEVICE),
+ 'loss': torch.tensor(0.0, device=DEVICE),
}
for _ in range(num_batches):
batch = next(self._eval_iters[split])
batch_metrics = self._eval_model(params, batch, model_rng)
total_metrics = {
- k: v + batch_metrics[k] for k, v in total_metrics.items()
+ k: v + batch_metrics[k] for k, v in total_metrics.items()
}
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
@@ -257,7 +274,6 @@ def _eval_model_on_split(self,
class FastMRIModelSizeWorkload(FastMRIWorkload):
-
@property
  def num_pool_layers(self) -> int:
    """Number of down-sampling and up-sampling layers in the model."""
@@ -278,7 +294,6 @@ def test_target_value(self) -> float:
class FastMRITanhWorkload(FastMRIWorkload):
-
@property
def use_tanh(self) -> bool:
"""Whether or not to use tanh activations in the model."""
@@ -294,7 +309,6 @@ def test_target_value(self) -> float:
class FastMRILayerNormWorkload(FastMRIWorkload):
-
@property
def use_layer_norm(self) -> bool:
"""Whether or not to use tanh activations in the model."""
diff --git a/algoperf/workloads/fastmri/input_pipeline.py b/algoperf/workloads/fastmri/input_pipeline.py
index f20611f43..62b3219c5 100644
--- a/algoperf/workloads/fastmri/input_pipeline.py
+++ b/algoperf/workloads/fastmri/input_pipeline.py
@@ -16,12 +16,9 @@
_EVAL_SEED = 0
-def _process_example(kspace,
- kspace_shape,
- target,
- target_shape,
- volume_max,
- seed):
+def _process_example(
+ kspace, kspace_shape, target, target_shape, volume_max, seed
+):
"""Generate a single example (slice from mri image).
Args:
@@ -45,15 +42,17 @@ def _process_example(kspace,
acceleration = tf.convert_to_tensor(4.0, dtype=tf.float32)
num_low_frequencies = tf.cast(
- num_cols_float * center_fraction, dtype=tf.int32)
+ num_cols_float * center_fraction, dtype=tf.int32
+ )
# calculate_center_mask
mask = tf.zeros(num_cols, dtype=tf.float32)
pad = (num_cols - num_low_frequencies + 1) // 2
mask = tf.tensor_scatter_nd_update(
- mask,
- tf.reshape(tf.range(pad, pad + num_low_frequencies), (-1, 1)),
- tf.ones(num_low_frequencies))
+ mask,
+ tf.reshape(tf.range(pad, pad + num_low_frequencies), (-1, 1)),
+ tf.ones(num_low_frequencies),
+ )
# reshape_mask
center_mask = tf.reshape(mask, (1, num_cols))
@@ -61,10 +60,12 @@ def _process_example(kspace,
# calculate_acceleration_mask
num_low_frequencies_float = tf.cast(num_low_frequencies, dtype=tf.float32)
prob = (num_cols_float / acceleration - num_low_frequencies_float) / (
- num_cols_float - num_low_frequencies_float)
+ num_cols_float - num_low_frequencies_float
+ )
mask = tf.cast(
- tf.random.stateless_uniform((num_cols,), seed) < prob, dtype=tf.float32)
+ tf.random.stateless_uniform((num_cols,), seed) < prob, dtype=tf.float32
+ )
acceleration_mask = tf.reshape(mask, (1, num_cols))
mask = tf.math.maximum(center_mask, acceleration_mask)
@@ -78,9 +79,11 @@ def _process_example(kspace,
shifted_image = tf.signal.ifft2d(shifted_kspace)
image = tf.signal.fftshift(shifted_image, axes=(0, 1))
scaling_norm = tf.cast(
- tf.math.sqrt(
- tf.cast(tf.math.reduce_prod(tf.shape(kspace)[-2:]), 'float32')),
- kspace.dtype)
+ tf.math.sqrt(
+ tf.cast(tf.math.reduce_prod(tf.shape(kspace)[-2:]), 'float32')
+ ),
+ kspace.dtype,
+ )
image = image * scaling_norm
image = tf.stack((tf.math.real(image), tf.math.imag(image)), axis=-1)
@@ -108,48 +111,58 @@ def _process_example(kspace,
target = tf.clip_by_value(norm_target, -6, 6)
return {
- 'inputs': image,
- 'targets': target,
- 'mean': mean,
- 'std': std,
- 'volume_max': volume_max,
+ 'inputs': image,
+ 'targets': target,
+ 'mean': mean,
+ 'std': std,
+ 'volume_max': volume_max,
}
def _h5_to_examples(path, log=False):
"""Yield MRI slices from an hdf5 file containing a single MRI volume."""
if log:
- tf.print('fastmri_dataset._h5_to_examples call:',
- path,
- datetime.datetime.now().strftime('%H:%M:%S:%f'))
+ tf.print(
+ 'fastmri_dataset._h5_to_examples call:',
+ path,
+ datetime.datetime.now().strftime('%H:%M:%S:%f'),
+ )
with open(path, 'rb') as gf:
with h5py.File(gf, 'r') as hf:
# NOTE(dsuo): logic taken from reference code
volume_max = hf.attrs.get('max', 0.0)
for i in range(hf['kspace'].shape[0]):
- yield hf['kspace'][i], hf['kspace'][i].shape, hf['reconstruction_esc'][
- i], hf['reconstruction_esc'][i].shape, volume_max
+ yield (
+ hf['kspace'][i],
+ hf['kspace'][i].shape,
+ hf['reconstruction_esc'][i],
+ hf['reconstruction_esc'][i].shape,
+ volume_max,
+ )
def _create_generator(filename):
signature = (
- tf.TensorSpec(shape=(640, None), dtype=tf.complex64),
- tf.TensorSpec(shape=(2,), dtype=tf.int32),
- tf.TensorSpec(shape=(320, 320), dtype=tf.float32),
- tf.TensorSpec(shape=(2,), dtype=tf.int32),
- tf.TensorSpec(shape=(), dtype=tf.float32),
+ tf.TensorSpec(shape=(640, None), dtype=tf.complex64),
+ tf.TensorSpec(shape=(2,), dtype=tf.int32),
+ tf.TensorSpec(shape=(320, 320), dtype=tf.float32),
+ tf.TensorSpec(shape=(2,), dtype=tf.int32),
+ tf.TensorSpec(shape=(), dtype=tf.float32),
)
return tf.data.Dataset.from_generator(
- _h5_to_examples, args=(filename,), output_signature=signature)
+ _h5_to_examples, args=(filename,), output_signature=signature
+ )
-def load_fastmri_split(global_batch_size,
- split,
- data_dir,
- shuffle_rng,
- num_batches,
- repeat_final_eval_dataset):
+def load_fastmri_split(
+ global_batch_size,
+ split,
+ data_dir,
+ shuffle_rng,
+ num_batches,
+ repeat_final_eval_dataset,
+):
"""Creates a split from the FastMRI dataset using tf.data.
NOTE: only creates knee singlecoil datasets.
@@ -169,11 +182,13 @@ def load_fastmri_split(global_batch_size,
# Check if data directories exist because glob will not raise an error
if not os.path.exists(os.path.join(data_dir, _TRAIN_DIR)):
- raise NotADirectoryError('Directory not found: {}'.format(
- os.path.join(data_dir, _TRAIN_DIR)))
+ raise NotADirectoryError(
+ 'Directory not found: {}'.format(os.path.join(data_dir, _TRAIN_DIR))
+ )
if not os.path.exists(os.path.join(data_dir, _VAL_DIR)):
- raise NotADirectoryError('Directory not found: {}'.format(
- os.path.join(data_dir, _VAL_DIR)))
+ raise NotADirectoryError(
+ 'Directory not found: {}'.format(os.path.join(data_dir, _VAL_DIR))
+ )
if split in ['train', 'eval_train']:
file_pattern = os.path.join(data_dir, _TRAIN_DIR, '*.h5')
@@ -190,10 +205,8 @@ def load_fastmri_split(global_batch_size,
shuffle = is_train or split == 'eval_train'
ds = tf.data.Dataset.from_tensor_slices(h5_paths)
ds = ds.interleave(
- _create_generator,
- cycle_length=32,
- block_length=64,
- num_parallel_calls=16)
+ _create_generator, cycle_length=32, block_length=64, num_parallel_calls=16
+ )
if is_train:
ds = ds.cache()
@@ -201,7 +214,8 @@ def process_example(example_index, example):
if shuffle:
process_rng = tf.cast(jax.random.fold_in(shuffle_rng, 0), tf.int64)
process_rng = tf.random.experimental.stateless_fold_in(
- process_rng, example_index)
+ process_rng, example_index
+ )
else:
# NOTE(dsuo): we use fixed randomness for eval.
process_rng = tf.cast(jax.random.PRNGKey(_EVAL_SEED), tf.int64)
@@ -211,9 +225,8 @@ def process_example(example_index, example):
if shuffle:
ds = ds.shuffle(
- 16 * global_batch_size,
- seed=shuffle_rng[0],
- reshuffle_each_iteration=True)
+ 16 * global_batch_size, seed=shuffle_rng[0], reshuffle_each_iteration=True
+ )
if is_train:
ds = ds.repeat()
@@ -231,7 +244,8 @@ def process_example(example_index, example):
ds = ds.repeat()
ds = ds.prefetch(10)
return map(
- functools.partial(
- data_utils.shard_and_maybe_pad_np,
- global_batch_size=global_batch_size),
- ds)
+ functools.partial(
+ data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size
+ ),
+ ds,
+ )
diff --git a/algoperf/workloads/fastmri/workload.py b/algoperf/workloads/fastmri/workload.py
index 051749cc3..0b1ecfaa1 100644
--- a/algoperf/workloads/fastmri/workload.py
+++ b/algoperf/workloads/fastmri/workload.py
@@ -8,7 +8,6 @@
class BaseFastMRIWorkload(spec.Workload):
-
@property
def target_metric_name(self) -> str:
"""The name of the target metric (useful for scoring/processing code)."""
@@ -61,8 +60,9 @@ def num_eval_train_examples(self) -> int:
# Round up from num_validation_examples (which is the default for
# num_eval_train_examples) to the next multiple of eval_batch_size, so that
# we don't have to extract the correctly sized subset of the training data.
- rounded_up_multiple = math.ceil(self.num_validation_examples /
- self.eval_batch_size)
+ rounded_up_multiple = math.ceil(
+ self.num_validation_examples / self.eval_batch_size
+ )
return rounded_up_multiple * self.eval_batch_size
@property
@@ -106,18 +106,22 @@ def step_hint(self) -> int:
"""Approx. steps the baseline can do in the allowed runtime budget."""
return 18_094
- def _build_input_queue(self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None):
+ def _build_input_queue(
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ):
del cache
- return input_pipeline.load_fastmri_split(global_batch_size,
- split,
- data_dir,
- data_rng,
- num_batches,
- repeat_final_dataset)
+ return input_pipeline.load_fastmri_split(
+ global_batch_size,
+ split,
+ data_dir,
+ data_rng,
+ num_batches,
+ repeat_final_dataset,
+ )
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py
index 3d6939218..53368b384 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py
@@ -1,5 +1,5 @@
"""
-Note:
+Note:
The following code is adapted from:
https://github.com/tensorflow/addons/tree/master/tensorflow_addons/image
@@ -12,35 +12,39 @@
import tensorflow as tf
_IMAGE_DTYPES = {
- tf.dtypes.uint8,
- tf.dtypes.int32,
- tf.dtypes.int64,
- tf.dtypes.float16,
- tf.dtypes.float32,
- tf.dtypes.float64,
+ tf.dtypes.uint8,
+ tf.dtypes.int32,
+ tf.dtypes.int64,
+ tf.dtypes.float16,
+ tf.dtypes.float32,
+ tf.dtypes.float64,
}
-Number = Union[float,
- int,
- np.float16,
- np.float32,
- np.float64,
- np.int8,
- np.int16,
- np.int32,
- np.int64,
- np.uint8,
- np.uint16,
- np.uint32,
- np.uint64,]
-
-TensorLike = Union[List[Union[Number, list]],
- tuple,
- Number,
- np.ndarray,
- tf.Tensor,
- tf.SparseTensor,
- tf.Variable,]
+Number = Union[
+ float,
+ int,
+ np.float16,
+ np.float32,
+ np.float64,
+ np.int8,
+ np.int16,
+ np.int32,
+ np.int64,
+ np.uint8,
+ np.uint16,
+ np.uint32,
+ np.uint64,
+]
+
+TensorLike = Union[
+ List[Union[Number, list]],
+ tuple,
+ Number,
+ np.ndarray,
+ tf.Tensor,
+ tf.SparseTensor,
+ tf.Variable,
+]
def get_ndims(image):
@@ -50,16 +54,19 @@ def get_ndims(image):
def to_4d_image(image):
"""Convert 2/3/4D image to 4D image.
- Args:
- image: 2/3/4D `Tensor`.
+ Args:
+ image: 2/3/4D `Tensor`.
- Returns:
- 4D `Tensor` with the same type.
- """
- with tf.control_dependencies([
+ Returns:
+ 4D `Tensor` with the same type.
+ """
+ with tf.control_dependencies(
+ [
tf.debugging.assert_rank_in(
- image, [2, 3, 4], message="`image` must be 2/3/4D tensor")
- ]):
+ image, [2, 3, 4], message='`image` must be 2/3/4D tensor'
+ )
+ ]
+ ):
ndims = image.get_shape().ndims
if ndims is None:
return _dynamic_to_4d_image(image)
@@ -80,12 +87,12 @@ def _dynamic_to_4d_image(image):
left_pad = tf.cast(tf.less_equal(original_rank, 3), dtype=tf.int32)
right_pad = tf.cast(tf.equal(original_rank, 2), dtype=tf.int32)
new_shape = tf.concat(
- [
- tf.ones(shape=left_pad, dtype=tf.int32),
- shape,
- tf.ones(shape=right_pad, dtype=tf.int32),
- ],
- axis=0,
+ [
+ tf.ones(shape=left_pad, dtype=tf.int32),
+ shape,
+ tf.ones(shape=right_pad, dtype=tf.int32),
+ ],
+ axis=0,
)
return tf.reshape(image, new_shape)
@@ -93,16 +100,16 @@ def _dynamic_to_4d_image(image):
def from_4d_image(image, ndims):
"""Convert back to an image with `ndims` rank.
- Args:
- image: 4D `Tensor`.
- ndims: The original rank of the image.
+ Args:
+ image: 4D `Tensor`.
+ ndims: The original rank of the image.
- Returns:
- `ndims`-D `Tensor` with the same type.
- """
+ Returns:
+ `ndims`-D `Tensor` with the same type.
+ """
with tf.control_dependencies(
- [tf.debugging.assert_rank(image, 4,
- message="`image` must be 4D tensor")]):
+ [tf.debugging.assert_rank(image, 4, message='`image` must be 4D tensor')]
+ ):
if isinstance(ndims, tf.Tensor):
return _dynamic_from_4d_image(image, ndims)
elif ndims == 2:
@@ -125,63 +132,64 @@ def _dynamic_from_4d_image(image, original_rank):
def transform(
- images: TensorLike,
- transforms: TensorLike,
- interpolation: str = "nearest",
- fill_mode: str = "constant",
- output_shape: Optional[list] = None,
- name: Optional[str] = None,
- fill_value: TensorLike = 0.0,
+ images: TensorLike,
+ transforms: TensorLike,
+ interpolation: str = 'nearest',
+ fill_mode: str = 'constant',
+ output_shape: Optional[list] = None,
+ name: Optional[str] = None,
+ fill_value: TensorLike = 0.0,
) -> tf.Tensor:
"""Applies the given transform(s) to the image(s).
- Args:
- images: A tensor of shape (num_images, num_rows, num_columns,
- num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or
- (num_rows, num_columns) (HW).
- transforms: Projective transform matrix/matrices. A vector of length 8 or
- tensor of size N x 8. If one row of transforms is
- [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point
- `(x, y)` to a transformed *input* point
- `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
- where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
- the transform mapping input points to output points. Note that
- gradients are not backpropagated into transformation parameters.
- interpolation: Interpolation mode.
- Supported values: "nearest", "bilinear".
- fill_mode: Points outside the boundaries of the input are filled according
- to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
- - *reflect*: `(d c b a | a b c d | d c b a)`
- The input is extended by reflecting about the edge of the last pixel.
- - *constant*: `(k k k k | a b c d | k k k k)`
- The input is extended by filling all values beyond the edge with the
- same constant value k = 0.
- - *wrap*: `(a b c d | a b c d | a b c d)`
- The input is extended by wrapping around to the opposite edge.
- - *nearest*: `(a a a a | a b c d | d d d d)`
- The input is extended by the nearest pixel.
- fill_value: a float represents the value to be filled outside the
- boundaries when `fill_mode` is "constant".
- output_shape: Output dimesion after the transform, [height, width].
- If None, output is the same size as input image.
-
- name: The name of the op.
-
- Returns:
- Image(s) with the same type and shape as `images`, with the given
- transform(s) applied. Transformed coordinates outside of the input image
- will be filled with zeros.
-
- Raises:
- TypeError: If `image` is an invalid type.
- ValueError: If output shape is not 1-D int32 Tensor.
- """
- with tf.name_scope(name or "transform"):
- image_or_images = tf.convert_to_tensor(images, name="images")
+ Args:
+ images: A tensor of shape (num_images, num_rows, num_columns,
+ num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or
+ (num_rows, num_columns) (HW).
+ transforms: Projective transform matrix/matrices. A vector of length 8 or
+ tensor of size N x 8. If one row of transforms is
+ [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point
+ `(x, y)` to a transformed *input* point
+ `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
+ where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
+ the transform mapping input points to output points. Note that
+ gradients are not backpropagated into transformation parameters.
+ interpolation: Interpolation mode.
+ Supported values: "nearest", "bilinear".
+ fill_mode: Points outside the boundaries of the input are filled according
+ to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
+ - *reflect*: `(d c b a | a b c d | d c b a)`
+ The input is extended by reflecting about the edge of the last pixel.
+ - *constant*: `(k k k k | a b c d | k k k k)`
+ The input is extended by filling all values beyond the edge with the
+ same constant value k = 0.
+ - *wrap*: `(a b c d | a b c d | a b c d)`
+ The input is extended by wrapping around to the opposite edge.
+ - *nearest*: `(a a a a | a b c d | d d d d)`
+ The input is extended by the nearest pixel.
+ fill_value: a float representing the value to be filled outside the
+ boundaries when `fill_mode` is "constant".
+ output_shape: Output dimension after the transform, [height, width].
+ If None, output is the same size as input image.
+
+ name: The name of the op.
+
+ Returns:
+ Image(s) with the same type and shape as `images`, with the given
+ transform(s) applied. Transformed coordinates outside of the input image
+ will be filled with zeros.
+
+ Raises:
+ TypeError: If `image` is an invalid type.
+ ValueError: If output shape is not 1-D int32 Tensor.
+ """
+ with tf.name_scope(name or 'transform'):
+ image_or_images = tf.convert_to_tensor(images, name='images')
transform_or_transforms = tf.convert_to_tensor(
- transforms, name="transforms", dtype=tf.dtypes.float32)
+ transforms, name='transforms', dtype=tf.dtypes.float32
+ )
if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
- raise TypeError("Invalid dtype %s." % image_or_images.dtype)
+ raise TypeError('Invalid dtype %s.' % image_or_images.dtype)
images = to_4d_image(image_or_images)
original_ndims = get_ndims(image_or_images)
@@ -189,61 +197,67 @@ def transform(
output_shape = tf.shape(images)[1:3]
output_shape = tf.convert_to_tensor(
- output_shape, tf.dtypes.int32, name="output_shape")
+ output_shape, tf.dtypes.int32, name='output_shape'
+ )
if not output_shape.get_shape().is_compatible_with([2]):
- raise ValueError("output_shape must be a 1-D Tensor of 2 elements: "
- "new_height, new_width")
+ raise ValueError(
+ 'output_shape must be a 1-D Tensor of 2 elements: new_height, new_width'
+ )
if len(transform_or_transforms.get_shape()) == 1:
transforms = transform_or_transforms[None]
elif transform_or_transforms.get_shape().ndims is None:
- raise ValueError("transforms rank must be statically known")
+ raise ValueError('transforms rank must be statically known')
elif len(transform_or_transforms.get_shape()) == 2:
transforms = transform_or_transforms
else:
transforms = transform_or_transforms
- raise ValueError("transforms should have rank 1 or 2, but got rank %d" %
- len(transforms.get_shape()))
+ raise ValueError(
+ 'transforms should have rank 1 or 2, but got rank %d'
+ % len(transforms.get_shape())
+ )
fill_value = tf.convert_to_tensor(
- fill_value, dtype=tf.float32, name="fill_value")
+ fill_value, dtype=tf.float32, name='fill_value'
+ )
output = tf.raw_ops.ImageProjectiveTransformV3(
- images=images,
- transforms=transforms,
- output_shape=output_shape,
- interpolation=interpolation.upper(),
- fill_mode=fill_mode.upper(),
- fill_value=fill_value,
+ images=images,
+ transforms=transforms,
+ output_shape=output_shape,
+ interpolation=interpolation.upper(),
+ fill_mode=fill_mode.upper(),
+ fill_value=fill_value,
)
return from_4d_image(output, original_ndims)
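To make the inverse-mapping convention documented above concrete, here is a small eager-mode check (an illustration, not part of the patch): with `[a0, a1, a2, b0, b1, b2, c0, c1] = [1, 0, 2, 0, 1, 0, 0, 0]`, every output pixel `(x, y)` samples the input at `(x + 2, y)`, so the content shifts left by two columns and the vacated columns take the constant `fill_value`.

import numpy as np
import tensorflow as tf

from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import (
  transform,
)

# 4x4 single-channel ramp image: each pixel stores its own column index.
image = tf.constant(np.tile(np.arange(4, dtype=np.float32), (4, 1))[..., None])

shifted = transform(image, [1.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0])

print(shifted[0, 0, 0].numpy())  # 2.0 -> output column 0 came from input column 2
print(shifted[0, 3, 0].numpy())  # 0.0 -> outside the input, filled with fill_value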
def angles_to_projective_transforms(
- angles: TensorLike,
- image_height: TensorLike,
- image_width: TensorLike,
- name: Optional[str] = None,
+ angles: TensorLike,
+ image_height: TensorLike,
+ image_width: TensorLike,
+ name: Optional[str] = None,
) -> tf.Tensor:
"""Returns projective transform(s) for the given angle(s).
- Args:
- angles: A scalar angle to rotate all images by, or (for batches of
- images) a vector with an angle to rotate each image in the batch. The
- rank must be statically known (the shape is not `TensorShape(None)`.
- image_height: Height of the image(s) to be transformed.
- image_width: Width of the image(s) to be transformed.
-
- Returns:
- A tensor of shape (num_images, 8). Projective transforms which can be
- given to `transform` op.
- """
- with tf.name_scope(name or "angles_to_projective_transforms"):
+ Args:
+ angles: A scalar angle to rotate all images by, or (for batches of
+ images) a vector with an angle to rotate each image in the batch. The
+ rank must be statically known (the shape is not `TensorShape(None)`).
+ image_height: Height of the image(s) to be transformed.
+ image_width: Width of the image(s) to be transformed.
+
+ Returns:
+ A tensor of shape (num_images, 8). Projective transforms which can be
+ given to `transform` op.
+ """
+ with tf.name_scope(name or 'angles_to_projective_transforms'):
angle_or_angles = tf.convert_to_tensor(
- angles, name="angles", dtype=tf.dtypes.float32)
+ angles, name='angles', dtype=tf.dtypes.float32
+ )
if len(angle_or_angles.get_shape()) not in (0, 1):
- raise ValueError("angles should have rank 0 or 1.")
+ raise ValueError('angles should have rank 0 or 1.')
if len(angle_or_angles.get_shape()) == 0:
angles = angle_or_angles[None]
@@ -252,112 +266,116 @@ def angles_to_projective_transforms(
cos_angles = tf.math.cos(angles)
sin_angles = tf.math.sin(angles)
- x_offset = ((image_width - 1) -
- (cos_angles * (image_width - 1) - sin_angles *
- (image_height - 1))) / 2.0
- y_offset = ((image_height - 1) -
- (sin_angles * (image_width - 1) + cos_angles *
- (image_height - 1))) / 2.0
+ x_offset = (
+ (image_width - 1)
+ - (cos_angles * (image_width - 1) - sin_angles * (image_height - 1))
+ ) / 2.0
+ y_offset = (
+ (image_height - 1)
+ - (sin_angles * (image_width - 1) + cos_angles * (image_height - 1))
+ ) / 2.0
num_angles = tf.shape(angles)[0]
return tf.concat(
- values=[
- cos_angles[:, None],
- -sin_angles[:, None],
- x_offset[:, None],
- sin_angles[:, None],
- cos_angles[:, None],
- y_offset[:, None],
- tf.zeros((num_angles, 2), tf.dtypes.float32),
- ],
- axis=1,
+ values=[
+ cos_angles[:, None],
+ -sin_angles[:, None],
+ x_offset[:, None],
+ sin_angles[:, None],
+ cos_angles[:, None],
+ y_offset[:, None],
+ tf.zeros((num_angles, 2), tf.dtypes.float32),
+ ],
+ axis=1,
)
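The offsets above re-centre the rotation about the image midpoint, so each returned row is `[cos a, -sin a, x_offset, sin a, cos a, y_offset, 0, 0]`. Two quick sanity checks of that arithmetic (illustrative only):

import numpy as np

from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import (
  angles_to_projective_transforms,
)

# angle = 0 gives the identity transform.
print(angles_to_projective_transforms(0.0, 8.0, 8.0).numpy())
# [[1. 0. 0. 0. 1. 0. 0. 0.]]

# A quarter turn of an 8x8 image: cos = 0, sin = 1, and the offsets become (7, 0).
print(np.round(angles_to_projective_transforms(np.pi / 2, 8.0, 8.0).numpy(), 3))
# approximately [[ 0. -1.  7.  1.  0.  0.  0.  0.]]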
def rotate_img(
- images: TensorLike,
- angles: TensorLike,
- interpolation: str = "nearest",
- fill_mode: str = "constant",
- name: Optional[str] = None,
- fill_value: TensorLike = 0.0,
+ images: TensorLike,
+ angles: TensorLike,
+ interpolation: str = 'nearest',
+ fill_mode: str = 'constant',
+ name: Optional[str] = None,
+ fill_value: TensorLike = 0.0,
) -> tf.Tensor:
"""Rotate image(s) counterclockwise by the passed angle(s) in radians.
- Args:
- images: A tensor of shape
- `(num_images, num_rows, num_columns, num_channels)`
- (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or
- `(num_rows, num_columns)` (HW).
- angles: A scalar angle to rotate all images by (if `images` has rank 4)
- a vector of length num_images, with an angle for each image in the
- batch.
- interpolation: Interpolation mode. Supported values: "nearest",
- "bilinear".
- fill_mode: Points outside the boundaries of the input are filled according
- to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
- - *reflect*: `(d c b a | a b c d | d c b a)`
- The input is extended by reflecting about the edge of the last pixel.
- - *constant*: `(k k k k | a b c d | k k k k)`
- The input is extended by filling all values beyond the edge with the
- same constant value k = 0.
- - *wrap*: `(a b c d | a b c d | a b c d)`
- The input is extended by wrapping around to the opposite edge.
- - *nearest*: `(a a a a | a b c d | d d d d)`
- The input is extended by the nearest pixel.
- fill_value: a float represents the value to be filled outside the
- boundaries when `fill_mode` is "constant".
- name: The name of the op.
-
- Returns:
- Image(s) with the same type and shape as `images`, rotated by the given
- angle(s). Empty space due to the rotation will be filled with zeros.
-
- Raises:
- TypeError: If `images` is an invalid type.
- """
- with tf.name_scope(name or "rotate"):
+ Args:
+ images: A tensor of shape
+ `(num_images, num_rows, num_columns, num_channels)`
+ (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or
+ `(num_rows, num_columns)` (HW).
+ angles: A scalar angle to rotate all images by, or (if `images` has rank 4)
+ a vector of length num_images, with an angle for each image in the
+ batch.
+ interpolation: Interpolation mode. Supported values: "nearest",
+ "bilinear".
+ fill_mode: Points outside the boundaries of the input are filled according
+ to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
+ - *reflect*: `(d c b a | a b c d | d c b a)`
+ The input is extended by reflecting about the edge of the last pixel.
+ - *constant*: `(k k k k | a b c d | k k k k)`
+ The input is extended by filling all values beyond the edge with the
+ same constant value k = 0.
+ - *wrap*: `(a b c d | a b c d | a b c d)`
+ The input is extended by wrapping around to the opposite edge.
+ - *nearest*: `(a a a a | a b c d | d d d d)`
+ The input is extended by the nearest pixel.
+ fill_value: a float representing the value to be filled outside the
+ boundaries when `fill_mode` is "constant".
+ name: The name of the op.
+
+ Returns:
+ Image(s) with the same type and shape as `images`, rotated by the given
+ angle(s). Empty space due to the rotation will be filled with zeros.
+
+ Raises:
+ TypeError: If `images` is an invalid type.
+ """
+ with tf.name_scope(name or 'rotate'):
image_or_images = tf.convert_to_tensor(images)
if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES:
- raise TypeError("Invalid dtype %s." % image_or_images.dtype)
+ raise TypeError('Invalid dtype %s.' % image_or_images.dtype)
images = to_4d_image(image_or_images)
original_ndims = get_ndims(image_or_images)
image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None]
image_width = tf.cast(tf.shape(images)[2], tf.dtypes.float32)[None]
output = transform(
- images,
- angles_to_projective_transforms(angles, image_height, image_width),
- interpolation=interpolation,
- fill_mode=fill_mode,
- fill_value=fill_value,
+ images,
+ angles_to_projective_transforms(angles, image_height, image_width),
+ interpolation=interpolation,
+ fill_mode=fill_mode,
+ fill_value=fill_value,
)
return from_4d_image(output, original_ndims)
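A minimal usage sketch for `rotate_img` (assuming, as in the upstream TensorFlow Addons code this file was forked from, that `uint8` is among the supported `_IMAGE_DTYPES`): rank and dtype of the input are preserved.

import numpy as np
import tensorflow as tf

from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import (
  rotate_img,
)

image = tf.constant(np.arange(16, dtype=np.uint8).reshape(4, 4))  # 2D (HW) image
rotated = rotate_img(image, np.pi / 2, interpolation='nearest')   # 90 deg CCW

print(rotated.shape, rotated.dtype)  # (4, 4) <dtype: 'uint8'>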
-def translations_to_projective_transforms(translations: TensorLike,
- name: Optional[str] = None
- ) -> tf.Tensor:
+def translations_to_projective_transforms(
+ translations: TensorLike, name: Optional[str] = None
+) -> tf.Tensor:
"""Returns projective transform(s) for the given translation(s).
- Args:
- translations: A 2-element list representing `[dx, dy]` or a matrix of
- 2-element lists representing `[dx, dy]` to translate for each image
- (for a batch of images). The rank must be statically known
- (the shape is not `TensorShape(None)`).
- name: The name of the op.
- Returns:
- A tensor of shape `(num_images, 8)` projective transforms which can be
- given to `tfa.image.transform`.
- """
- with tf.name_scope(name or "translations_to_projective_transforms"):
+ Args:
+ translations: A 2-element list representing `[dx, dy]` or a matrix of
+ 2-element lists representing `[dx, dy]` to translate for each image
+ (for a batch of images). The rank must be statically known
+ (the shape is not `TensorShape(None)`).
+ name: The name of the op.
+ Returns:
+ A tensor of shape `(num_images, 8)` projective transforms which can be
+ given to `tfa.image.transform`.
+ """
+ with tf.name_scope(name or 'translations_to_projective_transforms'):
translation_or_translations = tf.convert_to_tensor(
- translations, name="translations", dtype=tf.dtypes.float32)
+ translations, name='translations', dtype=tf.dtypes.float32
+ )
if translation_or_translations.get_shape().ndims is None:
raise TypeError(
- "translation_or_translations rank must be statically known")
+ 'translation_or_translations rank must be statically known'
+ )
if len(translation_or_translations.get_shape()) not in (1, 2):
- raise TypeError("Translations should have rank 1 or 2.")
+ raise TypeError('Translations should have rank 1 or 2.')
if len(translation_or_translations.get_shape()) == 1:
translations = translation_or_translations[None]
@@ -372,67 +390,67 @@ def translations_to_projective_transforms(translations: TensorLike,
# where the last entry is implicit.
# Translation matrices are always float32.
return tf.concat(
- values=[
- tf.ones((num_translations, 1), tf.dtypes.float32),
- tf.zeros((num_translations, 1), tf.dtypes.float32),
- -translations[:, 0, None],
- tf.zeros((num_translations, 1), tf.dtypes.float32),
- tf.ones((num_translations, 1), tf.dtypes.float32),
- -translations[:, 1, None],
- tf.zeros((num_translations, 2), tf.dtypes.float32),
- ],
- axis=1,
+ values=[
+ tf.ones((num_translations, 1), tf.dtypes.float32),
+ tf.zeros((num_translations, 1), tf.dtypes.float32),
+ -translations[:, 0, None],
+ tf.zeros((num_translations, 1), tf.dtypes.float32),
+ tf.ones((num_translations, 1), tf.dtypes.float32),
+ -translations[:, 1, None],
+ tf.zeros((num_translations, 2), tf.dtypes.float32),
+ ],
+ axis=1,
)
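The negated entries are deliberate: `transform` maps *output* coordinates back to *input* coordinates, so shifting content by `+dx` means sampling the input at `x - dx`. A one-line check (illustrative):

from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import (
  translations_to_projective_transforms,
)

# [dx, dy] = [3, 1] becomes the inverse map [1, 0, -3, 0, 1, -1, 0, 0].
print(translations_to_projective_transforms([3.0, 1.0]).numpy())
# [[ 1.  0. -3.  0.  1. -1.  0.  0.]]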
@tf.function
def translate(
- images: TensorLike,
- translations: TensorLike,
- interpolation: str = "nearest",
- fill_mode: str = "constant",
- name: Optional[str] = None,
- fill_value: TensorLike = 0.0,
+ images: TensorLike,
+ translations: TensorLike,
+ interpolation: str = 'nearest',
+ fill_mode: str = 'constant',
+ name: Optional[str] = None,
+ fill_value: TensorLike = 0.0,
) -> tf.Tensor:
"""Translate image(s) by the passed vectors(s).
- Args:
- images: A tensor of shape
- `(num_images, num_rows, num_columns, num_channels)` (NHWC),
- `(num_rows, num_columns, num_channels)` (HWC), or
- `(num_rows, num_columns)` (HW). The rank must be statically known (the
- shape is not `TensorShape(None)`).
- translations: A vector representing `[dx, dy]` or (if `images` has rank 4)
- a matrix of length num_images, with a `[dx, dy]` vector for each image
- in the batch.
- interpolation: Interpolation mode. Supported values: "nearest",
- "bilinear".
- fill_mode: Points outside the boundaries of the input are filled according
- to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
- - *reflect*: `(d c b a | a b c d | d c b a)`
- The input is extended by reflecting about the edge of the last pixel.
- - *constant*: `(k k k k | a b c d | k k k k)`
- The input is extended by filling all values beyond the edge with the
- same constant value k = 0.
- - *wrap*: `(a b c d | a b c d | a b c d)`
- The input is extended by wrapping around to the opposite edge.
- - *nearest*: `(a a a a | a b c d | d d d d)`
- The input is extended by the nearest pixel.
- fill_value: a float represents the value to be filled outside the
- boundaries when `fill_mode` is "constant".
- name: The name of the op.
- Returns:
- Image(s) with the same type and shape as `images`, translated by the
- given vector(s). Empty space due to the translation will be filled with
- zeros.
- Raises:
- TypeError: If `images` is an invalid type.
- """
- with tf.name_scope(name or "translate"):
+ Args:
+ images: A tensor of shape
+ `(num_images, num_rows, num_columns, num_channels)` (NHWC),
+ `(num_rows, num_columns, num_channels)` (HWC), or
+ `(num_rows, num_columns)` (HW). The rank must be statically known (the
+ shape is not `TensorShape(None)`).
+ translations: A vector representing `[dx, dy]` or (if `images` has rank 4)
+ a matrix of length num_images, with a `[dx, dy]` vector for each image
+ in the batch.
+ interpolation: Interpolation mode. Supported values: "nearest",
+ "bilinear".
+ fill_mode: Points outside the boundaries of the input are filled according
+ to the given mode (one of `{'constant', 'reflect', 'wrap', 'nearest'}`).
+ - *reflect*: `(d c b a | a b c d | d c b a)`
+ The input is extended by reflecting about the edge of the last pixel.
+ - *constant*: `(k k k k | a b c d | k k k k)`
+ The input is extended by filling all values beyond the edge with the
+ same constant value k = 0.
+ - *wrap*: `(a b c d | a b c d | a b c d)`
+ The input is extended by wrapping around to the opposite edge.
+ - *nearest*: `(a a a a | a b c d | d d d d)`
+ The input is extended by the nearest pixel.
+ fill_value: a float representing the value to be filled outside the
+ boundaries when `fill_mode` is "constant".
+ name: The name of the op.
+ Returns:
+ Image(s) with the same type and shape as `images`, translated by the
+ given vector(s). Empty space due to the translation will be filled with
+ zeros.
+ Raises:
+ TypeError: If `images` is an invalid type.
+ """
+ with tf.name_scope(name or 'translate'):
return transform(
- images,
- translations_to_projective_transforms(translations),
- interpolation=interpolation,
- fill_mode=fill_mode,
- fill_value=fill_value,
+ images,
+ translations_to_projective_transforms(translations),
+ interpolation=interpolation,
+ fill_mode=fill_mode,
+ fill_value=fill_value,
)
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py
index 66105335b..fc42f4a5b 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py
@@ -17,19 +17,21 @@
from algoperf.workloads.imagenet_resnet.imagenet_jax import randaugment
TFDS_SPLIT_NAME = {
- 'train': 'train', 'eval_train': 'train', 'validation': 'validation'
+ 'train': 'train',
+ 'eval_train': 'train',
+ 'validation': 'validation',
}
-def _distorted_bounding_box_crop(image_bytes: spec.Tensor,
- rng: spec.RandomState,
- bbox: spec.Tensor,
- min_object_covered: float = 0.1,
- aspect_ratio_range: Tuple[float,
- float] = (0.75,
- 1.33),
- area_range: Tuple[float, float] = (0.05, 1.0),
- max_attempts: int = 100) -> spec.Tensor:
+def _distorted_bounding_box_crop(
+ image_bytes: spec.Tensor,
+ rng: spec.RandomState,
+ bbox: spec.Tensor,
+ min_object_covered: float = 0.1,
+ aspect_ratio_range: Tuple[float, float] = (0.75, 1.33),
+ area_range: Tuple[float, float] = (0.05, 1.0),
+ max_attempts: int = 100,
+) -> spec.Tensor:
"""Generates cropped_image using one of the bboxes randomly distorted.
See `tf.image.sample_distorted_bounding_box` for more documentation.
@@ -57,14 +59,15 @@ def _distorted_bounding_box_crop(image_bytes: spec.Tensor,
"""
shape = tf.io.extract_jpeg_shape(image_bytes)
bbox_begin, bbox_size, _ = tf.image.stateless_sample_distorted_bounding_box(
- shape,
- seed=rng,
- bounding_boxes=bbox,
- min_object_covered=min_object_covered,
- aspect_ratio_range=aspect_ratio_range,
- area_range=area_range,
- max_attempts=max_attempts,
- use_image_if_no_bounding_boxes=True)
+ shape,
+ seed=rng,
+ bounding_boxes=bbox,
+ min_object_covered=min_object_covered,
+ aspect_ratio_range=aspect_ratio_range,
+ area_range=area_range,
+ max_attempts=max_attempts,
+ use_image_if_no_bounding_boxes=True,
+ )
# Crop the image to the specified bounding box.
offset_y, offset_x, _ = tf.unstack(bbox_begin)
@@ -84,8 +87,9 @@ def resize(image: spec.Tensor, image_size: int) -> spec.Tensor:
Returns:
Resized image 'Tensor'.
"""
- return tf.image.resize([image], [image_size, image_size],
- method=tf.image.ResizeMethod.BICUBIC)[0]
+ return tf.image.resize(
+ [image], [image_size, image_size], method=tf.image.ResizeMethod.BICUBIC
+ )[0]
def _at_least_x_are_equal(a: spec.Tensor, b: spec.Tensor, x: float) -> bool:
@@ -95,80 +99,93 @@ def _at_least_x_are_equal(a: spec.Tensor, b: spec.Tensor, x: float) -> bool:
return tf.greater_equal(tf.reduce_sum(match), x)
-def _decode_and_random_crop(image_bytes: spec.Tensor,
- rng: spec.RandomState,
- image_size: int,
- aspect_ratio_range: Tuple[float, float],
- area_range: Tuple[float, float],
- resize_size: int) -> spec.Tensor:
+def _decode_and_random_crop(
+ image_bytes: spec.Tensor,
+ rng: spec.RandomState,
+ image_size: int,
+ aspect_ratio_range: Tuple[float, float],
+ area_range: Tuple[float, float],
+ resize_size: int,
+) -> spec.Tensor:
"""Make a random crop of image_size."""
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
image = _distorted_bounding_box_crop(
- image_bytes,
- rng,
- bbox,
- min_object_covered=0.1,
- aspect_ratio_range=aspect_ratio_range,
- area_range=area_range,
- max_attempts=10)
+ image_bytes,
+ rng,
+ bbox,
+ min_object_covered=0.1,
+ aspect_ratio_range=aspect_ratio_range,
+ area_range=area_range,
+ max_attempts=10,
+ )
original_shape = tf.io.extract_jpeg_shape(image_bytes)
bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)
image = tf.cond(
- bad,
- lambda: _decode_and_center_crop(image_bytes, image_size, resize_size),
- lambda: resize(image, image_size))
+ bad,
+ lambda: _decode_and_center_crop(image_bytes, image_size, resize_size),
+ lambda: resize(image, image_size),
+ )
return image
-def _decode_and_center_crop(image_bytes: spec.Tensor,
- image_size: int,
- resize_size: int) -> spec.Tensor:
+def _decode_and_center_crop(
+ image_bytes: spec.Tensor, image_size: int, resize_size: int
+) -> spec.Tensor:
"""Crops to center of image with padding then scales image_size."""
shape = tf.io.extract_jpeg_shape(image_bytes)
image_height = shape[0]
image_width = shape[1]
padded_center_crop_size = tf.cast(
- ((image_size / resize_size) *
- tf.cast(tf.minimum(image_height, image_width), tf.float32)),
- tf.int32)
+ (
+ (image_size / resize_size)
+ * tf.cast(tf.minimum(image_height, image_width), tf.float32)
+ ),
+ tf.int32,
+ )
offset_height = ((image_height - padded_center_crop_size) + 1) // 2
offset_width = ((image_width - padded_center_crop_size) + 1) // 2
- crop_window = tf.stack([
+ crop_window = tf.stack(
+ [
offset_height,
offset_width,
padded_center_crop_size,
padded_center_crop_size,
- ])
+ ]
+ )
image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
image = resize(image, image_size)
return image
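The `padded_center_crop_size` expression above crops a fixed fraction `image_size / resize_size` of the shorter side before resizing. Plain-Python arithmetic for the common `image_size=224`, `resize_size=256` pair (input dimensions assumed for illustration):

image_size, resize_size = 224, 256
image_height, image_width = 480, 640

crop_fraction = image_size / resize_size                            # 0.875
padded_center_crop_size = int(crop_fraction * min(image_height, image_width))
print(padded_center_crop_size)                                      # 420

offset_height = (image_height - padded_center_crop_size + 1) // 2   # 30
offset_width = (image_width - padded_center_crop_size + 1) // 2     # 110
# A 420x420 window at (30, 110) is decoded and then resized to 224x224.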
-def normalize_image(image: spec.Tensor,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float]) -> spec.Tensor:
+def normalize_image(
+ image: spec.Tensor,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+) -> spec.Tensor:
image -= tf.constant(mean_rgb, shape=[1, 1, 3], dtype=image.dtype)
image /= tf.constant(stddev_rgb, shape=[1, 1, 3], dtype=image.dtype)
return image
-def preprocess_for_train(image_bytes: spec.Tensor,
- rng: spec.RandomState,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float],
- aspect_ratio_range: Tuple[float, float],
- area_range: Tuple[float, float],
- image_size: int,
- resize_size: int,
- dtype: tf.DType = tf.float32,
- use_randaug: bool = False,
- randaug_num_layers: int = 2,
- randaug_magnitude: int = 10) -> spec.Tensor:
+def preprocess_for_train(
+ image_bytes: spec.Tensor,
+ rng: spec.RandomState,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+ aspect_ratio_range: Tuple[float, float],
+ area_range: Tuple[float, float],
+ image_size: int,
+ resize_size: int,
+ dtype: tf.DType = tf.float32,
+ use_randaug: bool = False,
+ randaug_num_layers: int = 2,
+ randaug_magnitude: int = 10,
+) -> spec.Tensor:
"""Preprocesses the given image for training.
Args:
@@ -182,33 +199,36 @@ def preprocess_for_train(image_bytes: spec.Tensor,
"""
rngs = tf.random.experimental.stateless_split(rng, 3)
- image = _decode_and_random_crop(image_bytes,
- rngs[0],
- image_size,
- aspect_ratio_range,
- area_range,
- resize_size)
+ image = _decode_and_random_crop(
+ image_bytes,
+ rngs[0],
+ image_size,
+ aspect_ratio_range,
+ area_range,
+ resize_size,
+ )
image = tf.reshape(image, [image_size, image_size, 3])
image = tf.image.stateless_random_flip_left_right(image, seed=rngs[1])
if use_randaug:
image = tf.cast(tf.clip_by_value(image, 0, 255), tf.uint8)
- image = randaugment.distort_image_with_randaugment(image,
- randaug_num_layers,
- randaug_magnitude,
- rngs[2])
+ image = randaugment.distort_image_with_randaugment(
+ image, randaug_num_layers, randaug_magnitude, rngs[2]
+ )
image = tf.cast(image, tf.float32)
image = normalize_image(image, mean_rgb, stddev_rgb)
image = tf.image.convert_image_dtype(image, dtype=dtype)
return image
-def preprocess_for_eval(image_bytes: spec.Tensor,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float],
- image_size: int,
- resize_size: int,
- dtype: tf.DType = tf.float32) -> spec.Tensor:
+def preprocess_for_eval(
+ image_bytes: spec.Tensor,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+ image_size: int,
+ resize_size: int,
+ dtype: tf.DType = tf.float32,
+) -> spec.Tensor:
"""Preprocesses the given image for evaluation.
Args:
@@ -229,10 +249,12 @@ def preprocess_for_eval(image_bytes: spec.Tensor,
# Modified from
# github.com/google/init2winit/blob/master/init2winit/dataset_lib/ (cont. below)
# image_preprocessing.py.
-def mixup_tf(key: spec.RandomState,
- inputs: spec.Tensor,
- targets: spec.Tensor,
- alpha: float = 0.2) -> Tuple[spec.Tensor, spec.Tensor]:
+def mixup_tf(
+ key: spec.RandomState,
+ inputs: spec.Tensor,
+ targets: spec.Tensor,
+ alpha: float = 0.2,
+) -> Tuple[spec.Tensor, spec.Tensor]:
"""Perform mixup https://arxiv.org/abs/1710.09412.
NOTE: Code taken from https://github.com/google/big_vision with variables
@@ -261,24 +283,26 @@ def mixup_tf(key: spec.RandomState,
return inputs, targets
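For readers unfamiliar with mixup, the idea implemented by `mixup_tf` (whose body is unchanged in this hunk) is a Beta-distributed convex combination of a batch with a permuted copy of itself. A self-contained NumPy sketch of the technique, not a copy of the function above:

import numpy as np

def mixup_sketch(rng, inputs, one_hot_targets, alpha=0.2):
  """Convex-combine a batch with its reversed copy using a Beta(alpha, alpha) weight."""
  lam = rng.beta(alpha, alpha)
  lam = max(lam, 1.0 - lam)  # keep the larger weight on the original example
  mixed_inputs = lam * inputs + (1.0 - lam) * inputs[::-1]
  mixed_targets = lam * one_hot_targets + (1.0 - lam) * one_hot_targets[::-1]
  return mixed_inputs, mixed_targets

rng = np.random.default_rng(0)
x = np.arange(8, dtype=np.float32).reshape(4, 2)  # toy batch of 4 examples
y = np.eye(4, dtype=np.float32)                   # one-hot labels
mixed_x, mixed_y = mixup_sketch(rng, x, y)
print(mixed_y.sum(axis=-1))                       # each row still sums to 1.0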
-def create_split(split,
- dataset_builder,
- rng,
- global_batch_size,
- train,
- image_size,
- resize_size,
- mean_rgb,
- stddev_rgb,
- cache=False,
- repeat_final_dataset=False,
- aspect_ratio_range=(0.75, 4.0 / 3.0),
- area_range=(0.08, 1.0),
- use_mixup=False,
- mixup_alpha=0.1,
- use_randaug=False,
- randaug_num_layers=2,
- randaug_magnitude=10) -> Iterator[Dict[str, spec.Tensor]]:
+def create_split(
+ split,
+ dataset_builder,
+ rng,
+ global_batch_size,
+ train,
+ image_size,
+ resize_size,
+ mean_rgb,
+ stddev_rgb,
+ cache=False,
+ repeat_final_dataset=False,
+ aspect_ratio_range=(0.75, 4.0 / 3.0),
+ area_range=(0.08, 1.0),
+ use_mixup=False,
+ mixup_alpha=0.1,
+ use_randaug=False,
+ randaug_num_layers=2,
+ randaug_magnitude=10,
+) -> Iterator[Dict[str, spec.Tensor]]:
"""Creates a split from the ImageNet dataset using TensorFlow Datasets."""
shuffle_rng, preprocess_rng, mixup_rng = jax.random.split(rng, 3)
@@ -286,34 +310,35 @@ def decode_example(example_index, example):
dtype = tf.float32
if train:
per_step_preprocess_rng = tf.random.experimental.stateless_fold_in(
- tf.cast(preprocess_rng, tf.int64), example_index)
-
- image = preprocess_for_train(example['image'],
- per_step_preprocess_rng,
- mean_rgb,
- stddev_rgb,
- aspect_ratio_range,
- area_range,
- image_size,
- resize_size,
- dtype,
- use_randaug,
- randaug_num_layers,
- randaug_magnitude)
+ tf.cast(preprocess_rng, tf.int64), example_index
+ )
+
+ image = preprocess_for_train(
+ example['image'],
+ per_step_preprocess_rng,
+ mean_rgb,
+ stddev_rgb,
+ aspect_ratio_range,
+ area_range,
+ image_size,
+ resize_size,
+ dtype,
+ use_randaug,
+ randaug_num_layers,
+ randaug_magnitude,
+ )
else:
- image = preprocess_for_eval(example['image'],
- mean_rgb,
- stddev_rgb,
- image_size,
- resize_size,
- dtype)
+ image = preprocess_for_eval(
+ example['image'], mean_rgb, stddev_rgb, image_size, resize_size, dtype
+ )
return {'inputs': image, 'targets': example['label']}
ds = dataset_builder.as_dataset(
- split=TFDS_SPLIT_NAME[split],
- decoders={
- 'image': tfds.decode.SkipDecoding(),
- })
+ split=TFDS_SPLIT_NAME[split],
+ decoders={
+ 'image': tfds.decode.SkipDecoding(),
+ },
+ )
options = tf.data.Options()
options.threading.private_threadpool_size = 48
ds = ds.with_options(options)
@@ -336,18 +361,21 @@ def decode_example(example_index, example):
def mixup_batch(batch_index, batch):
per_batch_mixup_rng = tf.random.experimental.stateless_fold_in(
- mixup_rng, batch_index)
+ mixup_rng, batch_index
+ )
(inputs, targets) = mixup_tf(
- per_batch_mixup_rng,
- batch['inputs'],
- batch['targets'],
- alpha=mixup_alpha)
+ per_batch_mixup_rng,
+ batch['inputs'],
+ batch['targets'],
+ alpha=mixup_alpha,
+ )
batch['inputs'] = inputs
batch['targets'] = targets
return batch
ds = ds.enumerate().map(
- mixup_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+ mixup_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE
+ )
else:
raise ValueError('Mixup can only be used for the training split.')
@@ -359,44 +387,48 @@ def mixup_batch(batch_index, batch):
return ds
-def create_input_iter(split: str,
- dataset_builder: tfds.core.dataset_builder.DatasetBuilder,
- rng: spec.RandomState,
- global_batch_size: int,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float],
- image_size: int,
- resize_size: int,
- aspect_ratio_range: Tuple[float, float],
- area_range: Tuple[float, float],
- train: bool,
- cache: bool,
- repeat_final_dataset: bool,
- use_mixup: bool,
- mixup_alpha: float,
- use_randaug: bool) -> Iterator[Dict[str, spec.Tensor]]:
+def create_input_iter(
+ split: str,
+ dataset_builder: tfds.core.dataset_builder.DatasetBuilder,
+ rng: spec.RandomState,
+ global_batch_size: int,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+ image_size: int,
+ resize_size: int,
+ aspect_ratio_range: Tuple[float, float],
+ area_range: Tuple[float, float],
+ train: bool,
+ cache: bool,
+ repeat_final_dataset: bool,
+ use_mixup: bool,
+ mixup_alpha: float,
+ use_randaug: bool,
+) -> Iterator[Dict[str, spec.Tensor]]:
ds = create_split(
- split,
- dataset_builder,
- rng,
- global_batch_size,
- train=train,
- image_size=image_size,
- resize_size=resize_size,
- mean_rgb=mean_rgb,
- stddev_rgb=stddev_rgb,
- cache=cache,
- repeat_final_dataset=repeat_final_dataset,
- aspect_ratio_range=aspect_ratio_range,
- area_range=area_range,
- use_mixup=use_mixup,
- mixup_alpha=mixup_alpha,
- use_randaug=use_randaug)
+ split,
+ dataset_builder,
+ rng,
+ global_batch_size,
+ train=train,
+ image_size=image_size,
+ resize_size=resize_size,
+ mean_rgb=mean_rgb,
+ stddev_rgb=stddev_rgb,
+ cache=cache,
+ repeat_final_dataset=repeat_final_dataset,
+ aspect_ratio_range=aspect_ratio_range,
+ area_range=area_range,
+ use_mixup=use_mixup,
+ mixup_alpha=mixup_alpha,
+ use_randaug=use_randaug,
+ )
it = map(
- functools.partial(
- data_utils.shard_and_maybe_pad_np,
- global_batch_size=global_batch_size),
- ds)
+ functools.partial(
+ data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size
+ ),
+ ds,
+ )
# Note(Dan S): On a Nvidia 2080 Ti GPU, this increased GPU utilization by 10%.
it = jax_utils.prefetch_to_device(it, 2)
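`flax.jax_utils.prefetch_to_device(it, 2)` keeps two device-resident batches ahead of the training step; it expects each array's leading axis to equal the local device count. A toy, self-contained illustration of that contract (shapes chosen arbitrarily):

import jax
import numpy as np
from flax import jax_utils

n_dev = jax.local_device_count()

def host_batches():
  # Generator of already-sharded host batches: leading axis == device count.
  while True:
    yield {
      'inputs': np.ones((n_dev, 2, 8), np.float32),
      'targets': np.zeros((n_dev, 2), np.int32),
    }

it = jax_utils.prefetch_to_device(host_batches(), 2)
batch = next(it)
print(jax.tree_util.tree_map(lambda x: x.shape, batch))
# {'inputs': (n_dev, 2, 8), 'targets': (n_dev, 2)}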
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py
index ffa60b260..1f3911708 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py
@@ -17,12 +17,13 @@
class ResNetBlock(nn.Module):
"""ResNet block."""
+
filters: int
conv: ModuleDef
norm: ModuleDef
act: Callable
strides: Tuple[int, int] = (1, 1)
- bn_init_scale: float = 0.
+ bn_init_scale: float = 0.0
@nn.compact
def __call__(self, x: spec.Tensor) -> spec.Tensor:
@@ -35,8 +36,8 @@ def __call__(self, x: spec.Tensor) -> spec.Tensor:
if residual.shape != y.shape or self.strides != (1, 1):
residual = self.conv(
- self.filters, (1, 1), self.strides, name='Conv_proj')(
- residual)
+ self.filters, (1, 1), self.strides, name='Conv_proj'
+ )(residual)
residual = self.norm(name='BatchNorm_proj')(residual)
return self.act(residual + y)
@@ -44,6 +45,7 @@ def __call__(self, x: spec.Tensor) -> spec.Tensor:
class BottleneckResNetBlock(nn.Module):
"""Bottleneck ResNet block."""
+
filters: int
conv: ModuleDef
norm: ModuleDef
@@ -65,8 +67,8 @@ def __call__(self, x: spec.Tensor) -> spec.Tensor:
if residual.shape != y.shape or self.strides != (1, 1):
residual = self.conv(
- self.filters * 4, (1, 1), self.strides, name='Conv_proj')(
- residual)
+ self.filters * 4, (1, 1), self.strides, name='Conv_proj'
+ )(residual)
residual = self.norm(name='BatchNorm_proj')(residual)
return self.act(residual + y)
@@ -79,30 +81,35 @@ class ResNet(nn.Module):
num_filters: int = 64
dtype: Any = jnp.float32
act: Callable = nn.relu
- bn_init_scale: float = 0.
+ bn_init_scale: float = 0.0
@nn.compact
- def __call__(self,
- x: spec.Tensor,
- update_batch_norm: bool = True,
- use_running_average_bn: Optional[bool] = None) -> spec.Tensor:
+ def __call__(
+ self,
+ x: spec.Tensor,
+ update_batch_norm: bool = True,
+ use_running_average_bn: Optional[bool] = None,
+ ) -> spec.Tensor:
conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
# Preserve default behavior for backwards compatibility
if use_running_average_bn is None:
use_running_average_bn = not update_batch_norm
norm = functools.partial(
- nn.BatchNorm,
- use_running_average=use_running_average_bn,
- momentum=0.9,
- epsilon=1e-5,
- dtype=self.dtype)
+ nn.BatchNorm,
+ use_running_average=use_running_average_bn,
+ momentum=0.9,
+ epsilon=1e-5,
+ dtype=self.dtype,
+ )
x = conv(
- self.num_filters, (7, 7), (2, 2),
- padding=[(3, 3), (3, 3)],
- name='Conv_init')(
- x)
+ self.num_filters,
+ (7, 7),
+ (2, 2),
+ padding=[(3, 3), (3, 3)],
+ name='Conv_init',
+ )(x)
x = norm(name='BatchNorm_init')(x)
x = self.act(x)
x = nn.max_pool(x, (3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))
@@ -110,23 +117,23 @@ def __call__(self,
for j in range(block_size):
strides = (2, 2) if i > 0 and j == 0 else (1, 1)
x = self.block_cls(
- self.num_filters * 2**i,
- strides=strides,
- conv=conv,
- norm=norm,
- act=self.act,
- bn_init_scale=self.bn_init_scale)(
- x)
+ self.num_filters * 2**i,
+ strides=strides,
+ conv=conv,
+ norm=norm,
+ act=self.act,
+ bn_init_scale=self.bn_init_scale,
+ )(x)
x = jnp.mean(x, axis=(1, 2))
x = nn.Dense(
- self.num_classes,
- kernel_init=nn.initializers.normal(),
- dtype=self.dtype)(
- x)
+ self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype
+ )(x)
return x
ResNet18 = functools.partial(
- ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock)
+ ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock
+)
ResNet50 = functools.partial(
- ResNet, stage_sizes=(3, 4, 6, 3), block_cls=BottleneckResNetBlock)
+ ResNet, stage_sizes=(3, 4, 6, 3), block_cls=BottleneckResNetBlock
+)
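A minimal Flax usage sketch for the partials above (illustrative input shape and class count; the real workload builds ResNet50 with `num_classes=self._num_classes` inside `init_model_fn`):

import jax
import jax.numpy as jnp

from algoperf.workloads.imagenet_resnet.imagenet_jax.models import ResNet18

model = ResNet18(num_classes=10)
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 32, 32, 3)))

# Training-style call: batch statistics are mutable and get updated.
logits, new_state = model.apply(
  variables,
  jnp.ones((2, 32, 32, 3)),
  update_batch_norm=True,
  mutable=['batch_stats'],
)
print(logits.shape)  # (2, 10)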
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py
index c68e2de33..03b36e03d 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py
@@ -9,16 +9,19 @@
import tensorflow as tf
-from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \
- rotate_img
-from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \
- transform
-from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import \
- translate
+from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import (
+ rotate_img,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import (
+ transform,
+)
+from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import (
+ translate,
+)
# This signifies the max integer that the controller RNN could predict for the
# augmentation scheme.
-_MAX_LEVEL = 10.
+_MAX_LEVEL = 10.0
def blend(image1, image2, factor):
@@ -86,10 +89,12 @@ def cutout(image, pad_size, replace=0):
# Sample the center location in the image where the zero mask will be applied.
cutout_center_height = tf.random.uniform(
- shape=[], minval=0, maxval=image_height, dtype=tf.int32)
+ shape=[], minval=0, maxval=image_height, dtype=tf.int32
+ )
cutout_center_width = tf.random.uniform(
- shape=[], minval=0, maxval=image_width, dtype=tf.int32)
+ shape=[], minval=0, maxval=image_width, dtype=tf.int32
+ )
lower_pad = tf.maximum(0, cutout_center_height - pad_size)
upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size)
@@ -97,20 +102,18 @@ def cutout(image, pad_size, replace=0):
right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size)
cutout_shape = [
- image_height - (lower_pad + upper_pad),
- image_width - (left_pad + right_pad),
+ image_height - (lower_pad + upper_pad),
+ image_width - (left_pad + right_pad),
]
padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
mask = tf.pad(
- tf.zeros(cutout_shape, dtype=image.dtype),
- padding_dims,
- constant_values=1)
+ tf.zeros(cutout_shape, dtype=image.dtype), padding_dims, constant_values=1
+ )
mask = tf.expand_dims(mask, -1)
mask = tf.tile(mask, [1, 1, 3])
image = tf.where(
- tf.equal(mask, 0),
- tf.ones_like(image, dtype=image.dtype) * replace,
- image)
+ tf.equal(mask, 0), tf.ones_like(image, dtype=image.dtype) * replace, image
+ )
return image
@@ -204,7 +207,7 @@ def shear_x(image, level, replace):
# with a matrix form of:
# [1 level
# 0 1].
- image = transform(wrap(image), [1., level, 0., 0., 1., 0., 0., 0.])
+ image = transform(wrap(image), [1.0, level, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
return unwrap(image, replace)
@@ -214,7 +217,7 @@ def shear_y(image, level, replace):
# with a matrix form of:
# [1 0
# level 1].
- image = transform(wrap(image), [1., 0., 0., level, 1., 0., 0., 0.])
+ image = transform(wrap(image), [1.0, 0.0, 0.0, level, 1.0, 0.0, 0.0, 0.0])
return unwrap(image, replace)
@@ -264,9 +267,12 @@ def sharpness(image, factor):
# Make image 4D for conv operation.
image = tf.expand_dims(image, 0)
# SMOOTH PIL Kernel.
- kernel = tf.constant([[1, 1, 1], [1, 5, 1], [1, 1, 1]],
- dtype=tf.float32,
- shape=[3, 3, 1, 1]) / 13.
+ kernel = (
+ tf.constant(
+ [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, shape=[3, 3, 1, 1]
+ )
+ / 13.0
+ )
# Tile across channel dimension.
kernel = tf.tile(kernel, [1, 1, 3, 1])
strides = [1, 1, 1, 1]
@@ -274,7 +280,8 @@ def sharpness(image, factor):
# Some augmentation that uses depth-wise conv will cause crashing when
# training on GPU.
degenerate = tf.nn.depthwise_conv2d(
- image, kernel, strides, padding='VALID', dilations=[1, 1])
+ image, kernel, strides, padding='VALID', dilations=[1, 1]
+ )
degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0])
@@ -316,9 +323,10 @@ def build_lut(histo, step):
# If step is zero, return the original image. Otherwise, build
# lut from the full histogram and step and then index from it.
result = tf.cond(
- tf.equal(step, 0),
- lambda: im,
- lambda: tf.gather(build_lut(histo, step), im))
+ tf.equal(step, 0),
+ lambda: im,
+ lambda: tf.gather(build_lut(histo, step), im),
+ )
return tf.cast(result, tf.uint8)
@@ -373,9 +381,10 @@ def unwrap(image, replace):
# Where they are zero, fill them in with 'replace'.
flattened_image = tf.where(
- tf.equal(alpha_channel, 0),
- tf.ones_like(flattened_image, dtype=image.dtype) * replace,
- flattened_image)
+ tf.equal(alpha_channel, 0),
+ tf.ones_like(flattened_image, dtype=image.dtype) * replace,
+ flattened_image,
+ )
image = tf.reshape(flattened_image, image_shape)
image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3])
@@ -383,22 +392,22 @@ def unwrap(image, replace):
NAME_TO_FUNC = {
- 'AutoContrast': autocontrast,
- 'Equalize': equalize,
- 'Invert': invert,
- 'Rotate': rotate,
- 'Posterize': posterize,
- 'Solarize': solarize,
- 'SolarizeAdd': solarize_add,
- 'Color': color,
- 'Contrast': contrast,
- 'Brightness': brightness,
- 'Sharpness': sharpness,
- 'ShearX': shear_x,
- 'ShearY': shear_y,
- 'TranslateX': translate_x,
- 'TranslateY': translate_y,
- 'Cutout': cutout,
+ 'AutoContrast': autocontrast,
+ 'Equalize': equalize,
+ 'Invert': invert,
+ 'Rotate': rotate,
+ 'Posterize': posterize,
+ 'Solarize': solarize,
+ 'SolarizeAdd': solarize_add,
+ 'Color': color,
+ 'Contrast': contrast,
+ 'Brightness': brightness,
+ 'Sharpness': sharpness,
+ 'ShearX': shear_x,
+ 'ShearY': shear_y,
+ 'TranslateX': translate_x,
+ 'TranslateY': translate_y,
+ 'Cutout': cutout,
}
@@ -410,7 +419,7 @@ def _randomly_negate_tensor(tensor):
def _rotate_level_to_arg(level):
- level = (level / _MAX_LEVEL) * 30.
+ level = (level / _MAX_LEVEL) * 30.0
level = _randomly_negate_tensor(level)
return (level,)
@@ -435,47 +444,28 @@ def _translate_level_to_arg(level, translate_const):
def level_to_arg(cutout_const, translate_const):
return {
- 'AutoContrast':
- lambda level: (),
- 'Equalize':
- lambda level: (),
- 'Invert':
- lambda level: (),
- 'Rotate':
- _rotate_level_to_arg,
- 'Posterize':
- lambda level: (int((level / _MAX_LEVEL) * 4),),
- 'Solarize':
- lambda level: (int((level / _MAX_LEVEL) * 256),),
- 'SolarizeAdd':
- lambda level: (int((level / _MAX_LEVEL) * 110),),
- 'Color':
- _enhance_level_to_arg,
- 'Contrast':
- _enhance_level_to_arg,
- 'Brightness':
- _enhance_level_to_arg,
- 'Sharpness':
- _enhance_level_to_arg,
- 'ShearX':
- _shear_level_to_arg,
- 'ShearY':
- _shear_level_to_arg,
- 'Cutout':
- lambda level: (int((level / _MAX_LEVEL) * cutout_const),),
- 'TranslateX':
- lambda level: _translate_level_to_arg(level, translate_const),
- 'TranslateY':
- lambda level: _translate_level_to_arg(level, translate_const),
+ 'AutoContrast': lambda level: (),
+ 'Equalize': lambda level: (),
+ 'Invert': lambda level: (),
+ 'Rotate': _rotate_level_to_arg,
+ 'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4),),
+ 'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),),
+ 'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),),
+ 'Color': _enhance_level_to_arg,
+ 'Contrast': _enhance_level_to_arg,
+ 'Brightness': _enhance_level_to_arg,
+ 'Sharpness': _enhance_level_to_arg,
+ 'ShearX': _shear_level_to_arg,
+ 'ShearY': _shear_level_to_arg,
+ 'Cutout': lambda level: (int((level / _MAX_LEVEL) * cutout_const),),
+ 'TranslateX': lambda level: _translate_level_to_arg(level, translate_const),
+ 'TranslateY': lambda level: _translate_level_to_arg(level, translate_const),
}
-def _parse_policy_info(name,
- prob,
- level,
- replace_value,
- cutout_const,
- translate_const):
+def _parse_policy_info(
+ name, prob, level, replace_value, cutout_const, translate_const
+):
"""Return the function that corresponds to `name` and update `level` param."""
func = NAME_TO_FUNC[name]
args = level_to_arg(cutout_const, translate_const)[name](level)
@@ -514,45 +504,49 @@ def distort_image_with_randaugment(image, num_layers, magnitude, key):
"""
replace_value = [128] * 3
available_ops = [
- 'AutoContrast',
- 'Equalize',
- 'Invert',
- 'Rotate',
- 'Posterize',
- 'Solarize',
- 'Color',
- 'Contrast',
- 'Brightness',
- 'Sharpness',
- 'ShearX',
- 'ShearY',
- 'TranslateX',
- 'TranslateY',
- 'Cutout',
- 'SolarizeAdd',
+ 'AutoContrast',
+ 'Equalize',
+ 'Invert',
+ 'Rotate',
+ 'Posterize',
+ 'Solarize',
+ 'Color',
+ 'Contrast',
+ 'Brightness',
+ 'Sharpness',
+ 'ShearX',
+ 'ShearY',
+ 'TranslateX',
+ 'TranslateY',
+ 'Cutout',
+ 'SolarizeAdd',
]
for layer_num in range(num_layers):
key = tf.random.experimental.stateless_fold_in(key, layer_num)
- op_to_select = tf.random.stateless_uniform([],
- seed=key,
- maxval=len(available_ops),
- dtype=tf.int32)
+ op_to_select = tf.random.stateless_uniform(
+ [], seed=key, maxval=len(available_ops), dtype=tf.int32
+ )
random_magnitude = float(magnitude)
with tf.name_scope('randaug_layer_{}'.format(layer_num)):
- for (i, op_name) in enumerate(available_ops):
+ for i, op_name in enumerate(available_ops):
key = tf.random.experimental.stateless_fold_in(key, i)
- prob = tf.random.stateless_uniform([],
- seed=key,
- minval=0.2,
- maxval=0.8,
- dtype=tf.float32)
- func, _, args = _parse_policy_info(op_name, prob, random_magnitude,
- replace_value, cutout_const=40,
- translate_const=100)
+ prob = tf.random.stateless_uniform(
+ [], seed=key, minval=0.2, maxval=0.8, dtype=tf.float32
+ )
+ func, _, args = _parse_policy_info(
+ op_name,
+ prob,
+ random_magnitude,
+ replace_value,
+ cutout_const=40,
+ translate_const=100,
+ )
image = tf.cond(
- tf.equal(i, op_to_select),
- lambda selected_func=func,
- selected_args=args: selected_func(image, *selected_args),
- lambda: image)
+ tf.equal(i, op_to_select),
+ lambda selected_func=func, selected_args=args: selected_func(
+ image, *selected_args
+ ),
+ lambda: image,
+ )
return image
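A minimal eager-mode usage sketch for the reformatted function above. The `key` is a shape-`[2]` stateless-RNG seed, matching how the input pipeline passes `rngs[2]`; the dummy image is an assumption for illustration.

import tensorflow as tf

from algoperf.workloads.imagenet_resnet.imagenet_jax.randaugment import (
  distort_image_with_randaugment,
)

image = tf.zeros((32, 32, 3), dtype=tf.uint8)  # dummy uint8 image in [0, 255]
key = tf.constant([12, 34], dtype=tf.int64)    # stateless RNG seed of shape [2]

augmented = distort_image_with_randaugment(
  image, num_layers=2, magnitude=10, key=key
)
print(augmented.shape, augmented.dtype)  # (32, 32, 3) <dtype: 'uint8'>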
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
index 7896dcd05..c3035c212 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/workload.py
@@ -9,70 +9,75 @@
import math
from typing import Dict, Iterator, Optional, Tuple
-from flax import jax_utils
-from flax import linen as nn
-from flax.core import pop
import jax
-from jax import lax
import jax.numpy as jnp
import optax
import tensorflow_datasets as tfds
+from flax import jax_utils
+from flax import linen as nn
+from flax.core import pop
+from jax import lax
-from algoperf import param_utils
+from algoperf import param_utils, spec
from algoperf import random_utils as prng
-from algoperf import spec
from algoperf.workloads.imagenet_resnet import imagenet_v2
-from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline
-from algoperf.workloads.imagenet_resnet.imagenet_jax import models
-from algoperf.workloads.imagenet_resnet.workload import \
- BaseImagenetResNetWorkload
+from algoperf.workloads.imagenet_resnet.imagenet_jax import (
+ input_pipeline,
+ models,
+)
+from algoperf.workloads.imagenet_resnet.workload import (
+ BaseImagenetResNetWorkload,
+)
class ImagenetResNetWorkload(BaseImagenetResNetWorkload):
-
def _build_dataset(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- use_mixup: bool = False,
- use_randaug: bool = False) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ use_mixup: bool = False,
+ use_randaug: bool = False,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
if split == 'test':
np_iter = imagenet_v2.get_imagenet_v2_iter(
- data_dir,
- global_batch_size,
- mean_rgb=self.train_mean,
- stddev_rgb=self.train_stddev,
- image_size=self.center_crop_size,
- resize_size=self.resize_size)
+ data_dir,
+ global_batch_size,
+ mean_rgb=self.train_mean,
+ stddev_rgb=self.train_stddev,
+ image_size=self.center_crop_size,
+ resize_size=self.resize_size,
+ )
return itertools.cycle(np_iter)
ds_builder = tfds.builder('imagenet2012:5.1.0', data_dir=data_dir)
train = split == 'train'
ds = input_pipeline.create_input_iter(
- split,
- ds_builder,
- data_rng,
- global_batch_size,
- self.train_mean,
- self.train_stddev,
- self.center_crop_size,
- self.resize_size,
- self.aspect_ratio_range,
- self.scale_ratio_range,
- train=train,
- cache=not train if cache is None else cache,
- repeat_final_dataset=repeat_final_dataset,
- use_mixup=use_mixup,
- mixup_alpha=0.2,
- use_randaug=use_randaug)
+ split,
+ ds_builder,
+ data_rng,
+ global_batch_size,
+ self.train_mean,
+ self.train_stddev,
+ self.center_crop_size,
+ self.resize_size,
+ self.aspect_ratio_range,
+ self.scale_ratio_range,
+ train=train,
+ cache=not train if cache is None else cache,
+ repeat_final_dataset=repeat_final_dataset,
+ use_mixup=use_mixup,
+ mixup_alpha=0.2,
+ use_randaug=use_randaug,
+ )
return ds
def sync_batch_stats(
- self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState:
+ self, model_state: spec.ModelAuxiliaryState
+ ) -> spec.ModelAuxiliaryState:
"""Sync the batch statistics across replicas."""
# An axis_name is passed to pmap which can then be used by pmean.
# In this case each device has its own version of the batch statistics and
@@ -83,8 +88,8 @@ def sync_batch_stats(
return new_model_state
def init_model_fn(
- self,
- rng: spec.RandomState,
+ self,
+ rng: spec.RandomState,
) -> spec.ModelInitState:
model_cls = getattr(models, 'ResNet50')
@@ -98,15 +103,17 @@ def init_model_fn(
act_fnc = nn.relu
model = model_cls(
- num_classes=self._num_classes,
- act=act_fnc,
- bn_init_scale=self.bn_init_scale,
- dtype=jnp.float32)
+ num_classes=self._num_classes,
+ act=act_fnc,
+ bn_init_scale=self.bn_init_scale,
+ dtype=jnp.float32,
+ )
self._model = model
input_shape = (1, 224, 224, 3)
- variables = jax.jit(model.init)({'params': rng},
- jnp.ones(input_shape, model.dtype))
- model_state, params = pop(variables, "params")
+ variables = jax.jit(model.init)(
+ {'params': rng}, jnp.ones(input_shape, model.dtype)
+ )
+ model_state, params = pop(variables, 'params')
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
model_state = jax_utils.replicate(model_state)
@@ -117,63 +124,70 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_0'
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, 0, 0, 0, 0),
- static_broadcasted_argnums=(0,))
- def _eval_model(self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState) -> Dict[str, spec.Tensor]:
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, 0, 0, 0, 0),
+ static_broadcasted_argnums=(0,),
+ )
+ def _eval_model(
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ ) -> Dict[str, spec.Tensor]:
logits, _ = self.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.EVAL,
- rng=rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.EVAL,
+ rng=rng,
+ update_batch_norm=False,
+ )
weights = batch.get('weights')
return self._compute_metrics(logits, batch['targets'], weights)
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- use_running_average_bn: Optional[bool] = None
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ use_running_average_bn: Optional[bool] = None,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del mode
del rng
variables = {'params': params, **model_state}
if update_batch_norm:
logits, new_model_state = self._model.apply(
- variables,
- augmented_and_preprocessed_input_batch['inputs'],
- update_batch_norm=update_batch_norm,
- mutable=['batch_stats'],
- use_running_average_bn=use_running_average_bn)
+ variables,
+ augmented_and_preprocessed_input_batch['inputs'],
+ update_batch_norm=update_batch_norm,
+ mutable=['batch_stats'],
+ use_running_average_bn=use_running_average_bn,
+ )
return logits, new_model_state
else:
logits = self._model.apply(
- variables,
- augmented_and_preprocessed_input_batch['inputs'],
- update_batch_norm=update_batch_norm,
- mutable=False,
- use_running_average_bn=use_running_average_bn)
+ variables,
+ augmented_and_preprocessed_input_batch['inputs'],
+ update_batch_norm=update_batch_norm,
+ mutable=False,
+ use_running_average_bn=use_running_average_bn,
+ )
return logits, model_state
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -182,12 +196,14 @@ def loss_fn(
"""
if label_batch.shape[-1] != self._num_classes:
one_hot_labels = jax.nn.one_hot(
- label_batch, num_classes=self._num_classes)
+ label_batch, num_classes=self._num_classes
+ )
else:
one_hot_labels = label_batch
smoothed_labels = optax.smooth_labels(one_hot_labels, label_smoothing)
per_example_losses = -jnp.sum(
- smoothed_labels * jax.nn.log_softmax(logits_batch, axis=-1), axis=-1)
+ smoothed_labels * jax.nn.log_softmax(logits_batch, axis=-1), axis=-1
+ )
# mask_batch is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -196,36 +212,37 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': n_valid_examples,
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': n_valid_examples,
+ 'per_example': per_example_losses,
}
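A self-contained sketch of the smoothed, masked cross-entropy computed by `loss_fn` above, with toy logits and an assumed mask that drops a padded example:

import jax
import jax.numpy as jnp
import optax

num_classes = 4
labels = jnp.array([0, 2, 3])                 # dense integer labels
logits = jnp.array([[2.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 3.0, 0.0],
                    [0.0, 0.0, 0.0, 1.0]])
mask = jnp.array([1.0, 1.0, 0.0])             # third example is padding

one_hot = jax.nn.one_hot(labels, num_classes)
smoothed = optax.smooth_labels(one_hot, alpha=0.1)   # each row still sums to 1
per_example = -jnp.sum(smoothed * jax.nn.log_softmax(logits, axis=-1), axis=-1)
per_example = per_example * mask

print({'summed': per_example.sum(), 'n_valid_examples': mask.sum()})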
- def _compute_metrics(self,
- logits: spec.Tensor,
- labels: spec.Tensor,
- weights: spec.Tensor) -> Dict[str, spec.Tensor]:
+ def _compute_metrics(
+ self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor
+ ) -> Dict[str, spec.Tensor]:
if weights is None:
weights = jnp.ones(len(logits))
summed_loss = self.loss_fn(labels, logits, weights)['summed']
# not accuracy, but nr. of correct predictions
accuracy = jnp.sum((jnp.argmax(logits, -1) == labels) * weights)
metrics = {
- 'loss': summed_loss,
- 'accuracy': accuracy,
+ 'loss': summed_loss,
+ 'accuracy': accuracy,
}
metrics = lax.psum(metrics, axis_name='batch')
return metrics
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
del global_step
if model_state is not None:
# Sync batch statistics across replicas before evaluating.
@@ -235,13 +252,14 @@ def _eval_model_on_split(self,
# We already repeat the dataset indefinitely in tf.data.
if split not in self._eval_iters:
self._eval_iters[split] = self._build_input_queue(
- data_rng,
- split=split,
- global_batch_size=global_batch_size,
- data_dir=data_dir,
- cache=True,
- repeat_final_dataset=True,
- num_batches=num_batches)
+ data_rng,
+ split=split,
+ global_batch_size=global_batch_size,
+ data_dir=data_dir,
+ cache=True,
+ repeat_final_dataset=True,
+ num_batches=num_batches,
+ )
eval_metrics = {}
for bi in range(num_batches):
@@ -249,22 +267,21 @@ def _eval_model_on_split(self,
step_eval_rngs = prng.split(eval_rng, jax.local_device_count())
batch = next(self._eval_iters[split])
# We already average these metrics across devices inside _compute_metrics.
- synced_metrics = self._eval_model(params,
- batch,
- model_state,
- step_eval_rngs)
+ synced_metrics = self._eval_model(
+ params, batch, model_state, step_eval_rngs
+ )
for metric_name, metric_value in synced_metrics.items():
if metric_name not in eval_metrics:
eval_metrics[metric_name] = 0.0
eval_metrics[metric_name] += metric_value
- eval_metrics = jax.tree.map(lambda x: float(x[0] / num_examples),
- eval_metrics)
+ eval_metrics = jax.tree.map(
+ lambda x: float(x[0] / num_examples), eval_metrics
+ )
return eval_metrics
class ImagenetResNetSiLUWorkload(ImagenetResNetWorkload):
-
@property
def use_silu(self) -> bool:
return True
@@ -279,7 +296,6 @@ def test_target_value(self) -> float:
class ImagenetResNetGELUWorkload(ImagenetResNetWorkload):
-
@property
def use_gelu(self) -> bool:
return True
@@ -294,7 +310,6 @@ def test_target_value(self) -> float:
class ImagenetResNetLargeBNScaleWorkload(ImagenetResNetWorkload):
-
@property
def bn_init_scale(self) -> float:
return 8.0
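To make the label-smoothed loss reformatted above easier to follow, here is a self-contained toy re-computation of the same per-example reduction; the shapes and values are made up for illustration and are not part of the patch.
    import jax
    import jax.numpy as jnp
    import optax

    logits = jnp.zeros((3, 5))                 # 3 examples, 5 classes (toy values)
    labels = jnp.array([0, 2, 4])
    one_hot = jax.nn.one_hot(labels, num_classes=5)
    smoothed = optax.smooth_labels(one_hot, alpha=0.1)
    per_example = -jnp.sum(smoothed * jax.nn.log_softmax(logits, axis=-1), axis=-1)
    print(per_example.sum() / len(per_example))  # log(5) ~= 1.609 for all-zero logits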
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py
index aba9e671f..ab3fc4a37 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py
@@ -15,44 +15,49 @@
from algoperf.init_utils import pytorch_default_init
-def conv3x3(in_planes: int,
- out_planes: int,
- stride: int = 1,
- groups: int = 1,
- dilation: int = 1) -> nn.Conv2d:
+def conv3x3(
+ in_planes: int,
+ out_planes: int,
+ stride: int = 1,
+ groups: int = 1,
+ dilation: int = 1,
+) -> nn.Conv2d:
"""3x3 convolution with padding."""
return nn.Conv2d(
- in_planes,
- out_planes,
- kernel_size=3,
- stride=stride,
- padding=dilation,
- groups=groups,
- bias=False,
- dilation=dilation)
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation,
+ )
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
"""1x1 convolution."""
return nn.Conv2d(
- in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+ in_planes, out_planes, kernel_size=1, stride=stride, bias=False
+ )
class BasicBlock(nn.Module):
"""ResNet block."""
+
expansion: int = 1
def __init__(
- self,
- inplanes: int,
- planes: int,
- stride: int = 1,
- downsample: Optional[nn.Module] = None,
- groups: int = 1,
- base_width: int = 64,
- dilation: int = 1,
- norm_layer: Optional[Callable[..., nn.Module]] = None,
- act_fnc: nn.Module = nn.ReLU(inplace=True)
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ act_fnc: nn.Module = nn.ReLU(inplace=True),
) -> None:
super().__init__()
if norm_layer is None:
@@ -60,7 +65,7 @@ def __init__(
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
- raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ raise NotImplementedError('Dilation > 1 not supported in BasicBlock')
# Both self.conv1 and self.downsample layers downsample
# the input when stride != 1.
self.conv1 = conv3x3(inplanes, planes, stride)
@@ -92,24 +97,25 @@ def forward(self, x: spec.Tensor) -> spec.Tensor:
class Bottleneck(nn.Module):
"""Bottleneck ResNet block."""
+
expansion: int = 4
def __init__(
- self,
- inplanes: int,
- planes: int,
- stride: int = 1,
- downsample: Optional[nn.Module] = None,
- groups: int = 1,
- base_width: int = 64,
- dilation: int = 1,
- norm_layer: Optional[Callable[..., nn.Module]] = None,
- act_fnc: nn.Module = nn.ReLU(inplace=True)
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ act_fnc: nn.Module = nn.ReLU(inplace=True),
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
- width = int(planes * (base_width / 64.)) * groups
+ width = int(planes * (base_width / 64.0)) * groups
# Both self.conv2 and self.downsample layers downsample
# the input when stride != 1.
self.conv1 = conv1x1(inplanes, width)
@@ -146,18 +152,19 @@ def forward(self, x: Tensor) -> Tensor:
class ResNet(nn.Module):
-
- def __init__(self,
- block: Type[Union[BasicBlock, Bottleneck]],
- layers: List[int],
- num_classes: int = 1000,
- zero_init_residual: bool = True,
- groups: int = 1,
- width_per_group: int = 64,
- replace_stride_with_dilation: Optional[List[bool]] = None,
- norm_layer: Optional[Callable[..., nn.Module]] = None,
- act_fnc: nn.Module = nn.ReLU(inplace=True),
- bn_init_scale: float = 0.) -> None:
+ def __init__(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ num_classes: int = 1000,
+ zero_init_residual: bool = True,
+ groups: int = 1,
+ width_per_group: int = 64,
+ replace_stride_with_dilation: Optional[List[bool]] = None,
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
+ act_fnc: nn.Module = nn.ReLU(inplace=True),
+ bn_init_scale: float = 0.0,
+ ) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
@@ -171,37 +178,42 @@ def __init__(self,
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError(
- 'replace_stride_with_dilation should be None '
- f'or a 3-element tuple, got {replace_stride_with_dilation}')
+ 'replace_stride_with_dilation should be None '
+ f'or a 3-element tuple, got {replace_stride_with_dilation}'
+ )
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(
- 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+ 3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
+ )
self.bn1 = norm_layer(self.inplanes)
self.act_fnc = act_fnc
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, self.act_fnc, 64, layers[0])
self.layer2 = self._make_layer(
- block,
- self.act_fnc,
- 128,
- layers[1],
- stride=2,
- dilate=replace_stride_with_dilation[0])
+ block,
+ self.act_fnc,
+ 128,
+ layers[1],
+ stride=2,
+ dilate=replace_stride_with_dilation[0],
+ )
self.layer3 = self._make_layer(
- block,
- self.act_fnc,
- 256,
- layers[2],
- stride=2,
- dilate=replace_stride_with_dilation[1])
+ block,
+ self.act_fnc,
+ 256,
+ layers[2],
+ stride=2,
+ dilate=replace_stride_with_dilation[1],
+ )
self.layer4 = self._make_layer(
- block,
- self.act_fnc,
- 512,
- layers[3],
- stride=2,
- dilate=replace_stride_with_dilation[2])
+ block,
+ self.act_fnc,
+ 512,
+ layers[3],
+ stride=2,
+ dilate=replace_stride_with_dilation[2],
+ )
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
@@ -212,7 +224,7 @@ def __init__(self,
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
nn.init.normal_(self.fc.weight, std=1e-2)
- nn.init.constant_(self.fc.bias, 0.)
+ nn.init.constant_(self.fc.bias, 0.0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros,
@@ -226,13 +238,15 @@ def __init__(self,
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, bn_init_scale)
- def _make_layer(self,
- block: Type[Union[BasicBlock, Bottleneck]],
- act_fnc: nn.Module,
- planes: int,
- blocks: int,
- stride: int = 1,
- dilate: bool = False) -> nn.Sequential:
+ def _make_layer(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ act_fnc: nn.Module,
+ planes: int,
+ blocks: int,
+ stride: int = 1,
+ dilate: bool = False,
+ ) -> nn.Sequential:
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
@@ -241,34 +255,41 @@ def _make_layer(self,
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = torch.nn.Sequential(
- collections.OrderedDict([
- ("conv", conv1x1(self.inplanes, planes * block.expansion,
- stride)),
- ("bn", norm_layer(planes * block.expansion)),
- ]))
+ collections.OrderedDict(
+ [
+ ('conv', conv1x1(self.inplanes, planes * block.expansion, stride)),
+ ('bn', norm_layer(planes * block.expansion)),
+ ]
+ )
+ )
layers = []
layers.append(
- block(self.inplanes,
- planes,
- stride,
- downsample,
- self.groups,
- self.base_width,
- previous_dilation,
- norm_layer,
- act_fnc))
+ block(
+ self.inplanes,
+ planes,
+ stride,
+ downsample,
+ self.groups,
+ self.base_width,
+ previous_dilation,
+ norm_layer,
+ act_fnc,
+ )
+ )
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
- block(
- self.inplanes,
- planes,
- groups=self.groups,
- base_width=self.base_width,
- dilation=self.dilation,
- norm_layer=norm_layer,
- act_fnc=act_fnc))
+ block(
+ self.inplanes,
+ planes,
+ groups=self.groups,
+ base_width=self.base_width,
+ dilation=self.dilation,
+ norm_layer=norm_layer,
+ act_fnc=act_fnc,
+ )
+ )
return nn.Sequential(*layers)
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py
index c7a98e77a..28ce00650 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py
@@ -24,8 +24,8 @@ def cutout(img: spec.Tensor, pad_size: int) -> spec.Tensor:
# Double the pad size to match Jax implementation.
pad_size = pad_size * 2
- x0 = int(max(0, x0 - pad_size / 2.))
- y0 = int(max(0, y0 - pad_size / 2.))
+ x0 = int(max(0, x0 - pad_size / 2.0))
+ y0 = int(max(0, y0 - pad_size / 2.0))
x1 = int(min(image_width, x0 + pad_size))
y1 = int(min(image_height, y0 + pad_size))
xy = (x0, y0, x1, y1)
@@ -36,7 +36,7 @@ def cutout(img: spec.Tensor, pad_size: int) -> spec.Tensor:
def solarize(img: spec.Tensor, threshold: float) -> spec.Tensor:
img = np.array(img)
- new_img = np.where(img < threshold, img, 255. - img)
+ new_img = np.where(img < threshold, img, 255.0 - img)
return PIL.Image.fromarray(new_img.astype(np.uint8))
@@ -49,54 +49,56 @@ def solarize_add(img: spec.Tensor, addition: int = 0) -> spec.Tensor:
return PIL.Image.fromarray(new_img)
-def _apply_op(img: spec.Tensor,
- op_name: str,
- magnitude: float,
- interpolation: InterpolationMode,
- fill: Optional[List[float]]) -> spec.Tensor:
+def _apply_op(
+ img: spec.Tensor,
+ op_name: str,
+ magnitude: float,
+ interpolation: InterpolationMode,
+ fill: Optional[List[float]],
+) -> spec.Tensor:
if op_name == 'ShearX':
# Magnitude should be arctan(magnitude).
img = F.affine(
- img,
- angle=0.0,
- translate=[0, 0],
- scale=1.0,
- shear=[math.degrees(math.atan(magnitude)), 0.0],
- interpolation=interpolation,
- fill=fill,
- center=[0, 0],
+ img,
+ angle=0.0,
+ translate=[0, 0],
+ scale=1.0,
+ shear=[math.degrees(math.atan(magnitude)), 0.0],
+ interpolation=interpolation,
+ fill=fill,
+ center=[0, 0],
)
elif op_name == 'ShearY':
# Magnitude should be arctan(magnitude).
img = F.affine(
- img,
- angle=0.0,
- translate=[0, 0],
- scale=1.0,
- shear=[0.0, math.degrees(math.atan(magnitude))],
- interpolation=interpolation,
- fill=fill,
- center=[0, 0],
+ img,
+ angle=0.0,
+ translate=[0, 0],
+ scale=1.0,
+ shear=[0.0, math.degrees(math.atan(magnitude))],
+ interpolation=interpolation,
+ fill=fill,
+ center=[0, 0],
)
elif op_name == 'TranslateX':
img = F.affine(
- img,
- angle=0.0,
- translate=[int(magnitude), 0],
- scale=1.0,
- interpolation=interpolation,
- shear=[0.0, 0.0],
- fill=fill,
+ img,
+ angle=0.0,
+ translate=[int(magnitude), 0],
+ scale=1.0,
+ interpolation=interpolation,
+ shear=[0.0, 0.0],
+ fill=fill,
)
elif op_name == 'TranslateY':
img = F.affine(
- img,
- angle=0.0,
- translate=[0, int(magnitude)],
- scale=1.0,
- interpolation=interpolation,
- shear=[0.0, 0.0],
- fill=fill,
+ img,
+ angle=0.0,
+ translate=[0, int(magnitude)],
+ scale=1.0,
+ interpolation=interpolation,
+ shear=[0.0, 0.0],
+ fill=fill,
)
elif op_name == 'Rotate':
img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
@@ -131,33 +133,32 @@ def _apply_op(img: spec.Tensor,
def ops_space() -> Dict[str, Tuple[spec.Tensor, bool]]:
return {
- # op_name: (magnitudes, signed)
- 'ShearX': (torch.tensor(0.3), True),
- 'ShearY': (torch.tensor(0.3), True),
- 'TranslateX': (torch.tensor(100), True),
- 'TranslateY': (torch.tensor(100), True),
- 'Rotate': (torch.tensor(30), True),
- 'Brightness': (torch.tensor(1.9), False),
- 'Color': (torch.tensor(1.9), False),
- 'Contrast': (torch.tensor(1.9), False),
- 'Sharpness': (torch.tensor(1.9), False),
- 'Posterize': (torch.tensor(4), False),
- 'Solarize': (torch.tensor(256), False),
- 'SolarizeAdd': (torch.tensor(110), False),
- 'AutoContrast': (torch.tensor(0.0), False),
- 'Equalize': (torch.tensor(0.0), False),
- 'Invert': (torch.tensor(0.0), False),
- 'Cutout': (torch.tensor(40.0), False),
+ # op_name: (magnitudes, signed)
+ 'ShearX': (torch.tensor(0.3), True),
+ 'ShearY': (torch.tensor(0.3), True),
+ 'TranslateX': (torch.tensor(100), True),
+ 'TranslateY': (torch.tensor(100), True),
+ 'Rotate': (torch.tensor(30), True),
+ 'Brightness': (torch.tensor(1.9), False),
+ 'Color': (torch.tensor(1.9), False),
+ 'Contrast': (torch.tensor(1.9), False),
+ 'Sharpness': (torch.tensor(1.9), False),
+ 'Posterize': (torch.tensor(4), False),
+ 'Solarize': (torch.tensor(256), False),
+ 'SolarizeAdd': (torch.tensor(110), False),
+ 'AutoContrast': (torch.tensor(0.0), False),
+ 'Equalize': (torch.tensor(0.0), False),
+ 'Invert': (torch.tensor(0.0), False),
+ 'Cutout': (torch.tensor(40.0), False),
}
class RandAugment(torch.nn.Module):
-
def __init__(
- self,
- num_ops: int = 2,
- interpolation: InterpolationMode = InterpolationMode.NEAREST,
- fill: Optional[List[float]] = None,
+ self,
+ num_ops: int = 2,
+ interpolation: InterpolationMode = InterpolationMode.NEAREST,
+ fill: Optional[List[float]] = None,
) -> None:
super().__init__()
self.num_ops = num_ops
@@ -183,5 +184,6 @@ def forward(self, img: spec.Tensor) -> spec.Tensor:
# With 50% prob turn the magnitude negative.
magnitude *= -1.0
img = _apply_op(
- img, op_name, magnitude, interpolation=self.interpolation, fill=fill)
+ img, op_name, magnitude, interpolation=self.interpolation, fill=fill
+ )
return img
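To make the RandAugment interface above concrete, a small sketch of how it slots into a torchvision transform pipeline, much like the training branch of the PyTorch workload below. The PIL input image, crop size, and default num_ops=2 are illustrative assumptions.
    import PIL.Image
    from torchvision import transforms
    from algoperf.workloads.imagenet_resnet.imagenet_pytorch.randaugment import (
      RandAugment,
    )

    train_transform = transforms.Compose([
      transforms.RandomResizedCrop(224),
      transforms.RandomHorizontalFlip(),
      RandAugment(num_ops=2),   # applies 2 randomly chosen ops from ops_space()
      transforms.ToTensor(),
    ])
    img = PIL.Image.new('RGB', (256, 256))  # placeholder image
    augmented = train_transform(img)        # tensor of shape [3, 224, 224]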
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
index 285ba3b4b..85a35dc45 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/workload.py
@@ -16,22 +16,21 @@
from torchvision import transforms
from torchvision.datasets.folder import ImageFolder
-from algoperf import data_utils
-from algoperf import param_utils
-from algoperf import pytorch_utils
-from algoperf import spec
import algoperf.random_utils as prng
+from algoperf import data_utils, param_utils, pytorch_utils, spec
from algoperf.workloads.imagenet_resnet import imagenet_v2
from algoperf.workloads.imagenet_resnet.imagenet_pytorch import randaugment
from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import resnet50
-from algoperf.workloads.imagenet_resnet.workload import \
- BaseImagenetResNetWorkload
+from algoperf.workloads.imagenet_resnet.workload import (
+ BaseImagenetResNetWorkload,
+)
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
def imagenet_v2_to_torch(
- batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]:
+ batch: Dict[str, spec.Tensor],
+) -> Dict[str, spec.Tensor]:
# Slice off the part of the batch for this device and then transpose from
# [N, H, W, C] to [N, C, H, W]. Only transfer the inputs to GPU.
new_batch = {}
@@ -48,7 +47,6 @@ def imagenet_v2_to_torch(
class ImagenetResNetWorkload(BaseImagenetResNetWorkload):
-
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
# Is set in submission_runner.py for workloads with PyTorch evaluation
@@ -59,7 +57,8 @@ def __init__(self, *args, **kwargs) -> None:
def eval_num_workers(self) -> int:
if self._eval_num_workers is None:
raise ValueError(
- 'eval_num_workers property must be set before workload is used.')
+ 'eval_num_workers property must be set before workload is used.'
+ )
return self._eval_num_workers
@eval_num_workers.setter
@@ -67,60 +66,68 @@ def eval_num_workers(self, eval_num_workers: int):
self._eval_num_workers = eval_num_workers
def _build_dataset(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- use_mixup: bool = False,
- use_randaug: bool = False) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ use_mixup: bool = False,
+ use_randaug: bool = False,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
del cache
del repeat_final_dataset
if split == 'test':
np_iter = imagenet_v2.get_imagenet_v2_iter(
- data_dir,
- global_batch_size,
- mean_rgb=self.train_mean,
- stddev_rgb=self.train_stddev,
- image_size=self.center_crop_size,
- resize_size=self.resize_size)
+ data_dir,
+ global_batch_size,
+ mean_rgb=self.train_mean,
+ stddev_rgb=self.train_stddev,
+ image_size=self.center_crop_size,
+ resize_size=self.resize_size,
+ )
return map(imagenet_v2_to_torch, itertools.cycle(np_iter))
is_train = split == 'train'
normalize = transforms.Normalize(
- mean=[i / 255. for i in self.train_mean],
- std=[i / 255. for i in self.train_stddev])
+ mean=[i / 255.0 for i in self.train_mean],
+ std=[i / 255.0 for i in self.train_stddev],
+ )
if is_train:
transform_config = [
- transforms.RandomResizedCrop(
- self.center_crop_size,
- scale=self.scale_ratio_range,
- ratio=self.aspect_ratio_range),
- transforms.RandomHorizontalFlip(),
+ transforms.RandomResizedCrop(
+ self.center_crop_size,
+ scale=self.scale_ratio_range,
+ ratio=self.aspect_ratio_range,
+ ),
+ transforms.RandomHorizontalFlip(),
]
if use_randaug:
transform_config.append(randaugment.RandAugment())
transform_config.extend([transforms.ToTensor(), normalize])
transform_config = transforms.Compose(transform_config)
else:
- transform_config = transforms.Compose([
+ transform_config = transforms.Compose(
+ [
transforms.Resize(self.resize_size),
transforms.CenterCrop(self.center_crop_size),
transforms.ToTensor(),
normalize,
- ])
+ ]
+ )
folder = 'train' if 'train' in split else 'val'
dataset = ImageFolder(
- os.path.join(data_dir, folder), transform=transform_config)
+ os.path.join(data_dir, folder), transform=transform_config
+ )
if split == 'eval_train':
indices = list(range(self.num_train_examples))
random.Random(int(data_rng[0])).shuffle(indices)
- dataset = torch.utils.data.Subset(dataset,
- indices[:self.num_eval_train_examples])
+ dataset = torch.utils.data.Subset(
+ dataset, indices[: self.num_eval_train_examples]
+ )
sampler = None
if USE_PYTORCH_DDP:
@@ -131,26 +138,30 @@ def _build_dataset(
if USE_PYTORCH_DDP:
if is_train:
sampler = torch.utils.data.distributed.DistributedSampler(
- dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True)
+ dataset, num_replicas=N_GPUS, rank=RANK, shuffle=True
+ )
else:
sampler = data_utils.DistributedEvalSampler(
- dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False)
+ dataset, num_replicas=N_GPUS, rank=RANK, shuffle=False
+ )
dataloader = torch.utils.data.DataLoader(
- dataset,
- batch_size=ds_iter_batch_size,
- shuffle=not USE_PYTORCH_DDP and is_train,
- sampler=sampler,
- num_workers=4 if is_train else self.eval_num_workers,
- pin_memory=True,
- drop_last=is_train,
- persistent_workers=is_train)
+ dataset,
+ batch_size=ds_iter_batch_size,
+ shuffle=not USE_PYTORCH_DDP and is_train,
+ sampler=sampler,
+ num_workers=4 if is_train else self.eval_num_workers,
+ pin_memory=True,
+ drop_last=is_train,
+ persistent_workers=is_train,
+ )
dataloader = data_utils.PrefetchedWrapper(dataloader, DEVICE)
dataloader = data_utils.cycle(
- dataloader,
- custom_sampler=USE_PYTORCH_DDP,
- use_mixup=use_mixup,
- mixup_alpha=0.2)
+ dataloader,
+ custom_sampler=USE_PYTORCH_DDP,
+ use_mixup=use_mixup,
+ mixup_alpha=0.2,
+ )
return dataloader
@@ -181,14 +192,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['fc.weight', 'fc.bias']
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- dropout_rate: float = 0.0
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = 0.0,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
@@ -199,19 +210,22 @@ def model_fn(
if mode == spec.ForwardPassMode.EVAL:
if update_batch_norm:
raise ValueError(
- 'Batch norm statistics cannot be updated during evaluation.')
+ 'Batch norm statistics cannot be updated during evaluation.'
+ )
model.eval()
if mode == spec.ForwardPassMode.TRAIN:
model.train()
model.apply(
- functools.partial(
- pytorch_utils.update_batch_norm_fn,
- update_batch_norm=update_batch_norm))
+ functools.partial(
+ pytorch_utils.update_batch_norm_fn,
+ update_batch_norm=update_batch_norm,
+ )
+ )
contexts = {
- spec.ForwardPassMode.EVAL: torch.no_grad,
- spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
+ spec.ForwardPassMode.EVAL: torch.no_grad,
+ spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
@@ -222,11 +236,12 @@ def model_fn(
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -234,10 +249,11 @@ def loss_fn(
(not synced across devices).
"""
per_example_losses = F.cross_entropy(
- logits_batch,
- label_batch,
- reduction='none',
- label_smoothing=label_smoothing)
+ logits_batch,
+ label_batch,
+ reduction='none',
+ label_smoothing=label_smoothing,
+ )
# `mask_batch` is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -246,15 +262,14 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
+ 'per_example': per_example_losses,
}
- def _compute_metrics(self,
- logits: spec.Tensor,
- labels: spec.Tensor,
- weights: spec.Tensor) -> Dict[str, spec.Tensor]:
+ def _compute_metrics(
+ self, logits: spec.Tensor, labels: spec.Tensor, weights: spec.Tensor
+ ) -> Dict[str, spec.Tensor]:
"""Return the mean accuracy and loss as a dict."""
if weights is None:
weights = torch.ones(len(logits), device=DEVICE)
@@ -264,15 +279,17 @@ def _compute_metrics(self,
summed_loss = self.loss_fn(labels, logits, weights)['summed']
return {'accuracy': accuracy, 'loss': summed_loss}
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del global_step
data_rng, model_rng = prng.split(rng, 2)
@@ -280,31 +297,33 @@ def _eval_model_on_split(self,
is_test = split == 'test'
# These iterators repeat indefinitely.
self._eval_iters[split] = self._build_input_queue(
- data_rng,
- split=split,
- global_batch_size=global_batch_size,
- data_dir=data_dir,
- cache=is_test,
- repeat_final_dataset=is_test)
+ data_rng,
+ split=split,
+ global_batch_size=global_batch_size,
+ data_dir=data_dir,
+ cache=is_test,
+ repeat_final_dataset=is_test,
+ )
total_metrics = {
- 'accuracy': torch.tensor(0., device=DEVICE),
- 'loss': torch.tensor(0., device=DEVICE),
+ 'accuracy': torch.tensor(0.0, device=DEVICE),
+ 'loss': torch.tensor(0.0, device=DEVICE),
}
num_batches = int(math.ceil(num_examples / global_batch_size))
for _ in range(num_batches):
batch = next(self._eval_iters[split])
logits, _ = self.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.EVAL,
- model_rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.EVAL,
+ model_rng,
+ update_batch_norm=False,
+ )
weights = batch.get('weights')
batch_metrics = self._compute_metrics(logits, batch['targets'], weights)
total_metrics = {
- k: v + batch_metrics[k] for k, v in total_metrics.items()
+ k: v + batch_metrics[k] for k, v in total_metrics.items()
}
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
@@ -313,7 +332,6 @@ def _eval_model_on_split(self,
class ImagenetResNetSiLUWorkload(ImagenetResNetWorkload):
-
@property
def use_silu(self) -> bool:
return True
@@ -328,7 +346,6 @@ def test_target_value(self) -> float:
class ImagenetResNetGELUWorkload(ImagenetResNetWorkload):
-
@property
def use_gelu(self) -> bool:
return True
@@ -343,7 +360,6 @@ def test_target_value(self) -> float:
class ImagenetResNetLargeBNScaleWorkload(ImagenetResNetWorkload):
-
@property
def bn_init_scale(self) -> float:
return 8.0
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_v2.py b/algoperf/workloads/imagenet_resnet/imagenet_v2.py
index 84d364586..6ffb73367 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_v2.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_v2.py
@@ -13,32 +13,34 @@
from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline
-def get_imagenet_v2_iter(data_dir: str,
- global_batch_size: int,
- mean_rgb: Tuple[float, float, float],
- stddev_rgb: Tuple[float, float, float],
- image_size: int,
- resize_size: int) -> Iterator[Dict[str, spec.Tensor]]:
+def get_imagenet_v2_iter(
+ data_dir: str,
+ global_batch_size: int,
+ mean_rgb: Tuple[float, float, float],
+ stddev_rgb: Tuple[float, float, float],
+ image_size: int,
+ resize_size: int,
+) -> Iterator[Dict[str, spec.Tensor]]:
"""Always caches and repeats indefinitely."""
ds = tfds.load(
- 'imagenet_v2/matched-frequency:3.0.0',
- split='test',
- data_dir=data_dir,
- decoders={
- 'image': tfds.decode.SkipDecoding(),
- })
+ 'imagenet_v2/matched-frequency:3.0.0',
+ split='test',
+ data_dir=data_dir,
+ decoders={
+ 'image': tfds.decode.SkipDecoding(),
+ },
+ )
def _decode_example(example: Dict[str, float]) -> Dict[str, float]:
- image = input_pipeline.preprocess_for_eval(example['image'],
- mean_rgb,
- stddev_rgb,
- image_size,
- resize_size)
+ image = input_pipeline.preprocess_for_eval(
+ example['image'], mean_rgb, stddev_rgb, image_size, resize_size
+ )
return {'inputs': image, 'targets': example['label']}
ds = ds.map(_decode_example, num_parallel_calls=16)
ds = ds.batch(global_batch_size)
shard_pad_fn = functools.partial(
- data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size)
+ data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size
+ )
it = map(shard_pad_fn, iter(ds))
return it
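For orientation, a sketch of how get_imagenet_v2_iter is typically consumed; the data directory and RGB statistics below are placeholders, not the workload's actual values.
    from algoperf.workloads.imagenet_resnet.imagenet_v2 import get_imagenet_v2_iter

    it = get_imagenet_v2_iter(
      data_dir='/data/tensorflow_datasets',  # placeholder path
      global_batch_size=256,
      mean_rgb=(123.68, 116.78, 103.94),     # placeholder normalization stats
      stddev_rgb=(58.39, 57.12, 57.38),
      image_size=224,
      resize_size=256,
    )
    batch = next(it)  # dict with 'inputs' and 'targets', sharded across local devices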
diff --git a/algoperf/workloads/imagenet_resnet/workload.py b/algoperf/workloads/imagenet_resnet/workload.py
index 83fe97108..ef696e328 100644
--- a/algoperf/workloads/imagenet_resnet/workload.py
+++ b/algoperf/workloads/imagenet_resnet/workload.py
@@ -7,7 +7,6 @@
class BaseImagenetResNetWorkload(spec.Workload):
-
_num_classes: int = 1000
@property
@@ -15,8 +14,9 @@ def target_metric_name(self) -> str:
"""The name of the target metric (useful for scoring/processing code)."""
return 'accuracy'
- def has_reached_validation_target(self, eval_result: Dict[str,
- float]) -> bool:
+ def has_reached_validation_target(
+ self, eval_result: Dict[str, float]
+ ) -> bool:
return eval_result['validation/accuracy'] > self.validation_target_value
@property
@@ -58,8 +58,9 @@ def num_eval_train_examples(self) -> int:
# Round up from num_validation_examples (which is the default for
# num_eval_train_examples) to the next multiple of eval_batch_size, so that
# we don't have to extract the correctly sized subset of the training data.
- rounded_up_multiple = math.ceil(self.num_validation_examples /
- self.eval_batch_size)
+ rounded_up_multiple = math.ceil(
+ self.num_validation_examples / self.eval_batch_size
+ )
return rounded_up_multiple * self.eval_batch_size
@property
@@ -109,38 +110,37 @@ def eval_period_time_sec(self) -> int:
return 510 # 8.5 minutes.
def _build_dataset(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- use_mixup: bool = False,
- use_randaug: bool = False) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ use_mixup: bool = False,
+ use_randaug: bool = False,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
raise NotImplementedError
def _build_input_queue(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
del num_batches
if split == 'test':
if not cache:
raise ValueError('cache must be True for split=test.')
if not repeat_final_dataset:
raise ValueError('repeat_final_dataset must be True for split=test.')
- return self._build_dataset(data_rng,
- split,
- data_dir,
- global_batch_size,
- cache,
- repeat_final_dataset)
+ return self._build_dataset(
+ data_rng, split, data_dir, global_batch_size, cache, repeat_final_dataset
+ )
@property
def step_hint(self) -> int:
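A worked example of the num_eval_train_examples rounding above, using placeholder numbers rather than the workload's actual values.
    import math

    num_validation_examples = 50_000  # placeholder
    eval_batch_size = 1_024           # placeholder
    rounded_up_multiple = math.ceil(num_validation_examples / eval_batch_size)  # 49
    print(rounded_up_multiple * eval_batch_size)  # 50176, smallest multiple >= 50000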
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
index f33dea723..5e38acd8b 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/models.py
@@ -7,8 +7,8 @@
from typing import Optional, Sequence, Union
-from flax import linen as nn
import jax.numpy as jnp
+from flax import linen as nn
from algoperf import spec
from algoperf.jax_utils import Dropout
@@ -16,18 +16,20 @@
DROPOUT_RATE = 0.0
-def posemb_sincos_2d(h: int,
- w: int,
- width: int,
- temperature: int = 10_000.,
- dtype: jnp.dtype = jnp.float32) -> spec.Tensor:
+def posemb_sincos_2d(
+ h: int,
+ w: int,
+ width: int,
+ temperature: int = 10_000.0,
+ dtype: jnp.dtype = jnp.float32,
+) -> spec.Tensor:
"""Follows the MoCo v3 logic."""
- y, x = jnp.mgrid[:h, :w] #pylint: disable=unpacking-non-sequence
+ y, x = jnp.mgrid[:h, :w] # pylint: disable=unpacking-non-sequence
if width % 4 != 0:
raise ValueError('Width must be mult of 4 for sincos posemb.')
omega = jnp.arange(width // 4) / (width // 4 - 1)
- omega = 1. / (temperature**omega)
+ omega = 1.0 / (temperature**omega)
y = jnp.einsum('m,d->md', y.flatten(), omega)
x = jnp.einsum('m,d->md', x.flatten(), omega)
pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1)
@@ -36,19 +38,19 @@ def posemb_sincos_2d(h: int,
class MlpBlock(nn.Module):
"""Transformer MLP / feed-forward block."""
+
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
use_glu: bool = False
dropout_rate: float = DROPOUT_RATE
@nn.compact
- def __call__(self,
- x: spec.Tensor,
- train: bool = True,
- dropout_rate=DROPOUT_RATE) -> spec.Tensor:
+ def __call__(
+ self, x: spec.Tensor, train: bool = True, dropout_rate=DROPOUT_RATE
+ ) -> spec.Tensor:
"""Applies Transformer MlpBlock module."""
inits = {
- 'kernel_init': nn.initializers.xavier_uniform(),
- 'bias_init': nn.initializers.normal(stddev=1e-6),
+ 'kernel_init': nn.initializers.xavier_uniform(),
+ 'bias_init': nn.initializers.normal(stddev=1e-6),
}
d = x.shape[2]
@@ -66,6 +68,7 @@ def __call__(self,
class Encoder1DBlock(nn.Module):
"""Single transformer encoder block (MHSA + MLP)."""
+
mlp_dim: Optional[int] = None # Defaults to 4x input dim.
num_heads: int = 12
use_glu: bool = False
@@ -73,47 +76,45 @@ class Encoder1DBlock(nn.Module):
dropout_rate: float = 0.0
@nn.compact
- def __call__(self,
- x: spec.Tensor,
- train: bool = True,
- dropout_rate=dropout_rate) -> spec.Tensor:
-
+ def __call__(
+ self, x: spec.Tensor, train: bool = True, dropout_rate=dropout_rate
+ ) -> spec.Tensor:
if not self.use_post_layer_norm:
y = nn.LayerNorm(name='LayerNorm_0')(x)
y = nn.MultiHeadDotProductAttention(
- num_heads=self.num_heads,
- kernel_init=nn.initializers.xavier_uniform(),
- deterministic=train,
- name='MultiHeadDotProductAttention_1')(
- y)
+ num_heads=self.num_heads,
+ kernel_init=nn.initializers.xavier_uniform(),
+ deterministic=train,
+ name='MultiHeadDotProductAttention_1',
+ )(y)
y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
x = x + y
y = nn.LayerNorm(name='LayerNorm_2')(x)
y = MlpBlock(
- mlp_dim=self.mlp_dim, use_glu=self.use_glu, name='MlpBlock_3')(
- y, train, dropout_rate=dropout_rate)
+ mlp_dim=self.mlp_dim, use_glu=self.use_glu, name='MlpBlock_3'
+ )(y, train, dropout_rate=dropout_rate)
y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
x = x + y
else:
y = x
y = nn.MultiHeadDotProductAttention(
- num_heads=self.num_heads,
- kernel_init=nn.initializers.xavier_uniform(),
- deterministic=train,
- name='MultiHeadDotProductAttention_1')(
- y)
+ num_heads=self.num_heads,
+ kernel_init=nn.initializers.xavier_uniform(),
+ deterministic=train,
+ name='MultiHeadDotProductAttention_1',
+ )(y)
y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
x = x + y
x = nn.LayerNorm(name='LayerNorm_0')(x)
y = x
y = MlpBlock(
- mlp_dim=self.mlp_dim,
- use_glu=self.use_glu,
- name='MlpBlock_3',
- dropout_rate=dropout_rate)(
- y, train, dropout_rate=dropout_rate)
+ mlp_dim=self.mlp_dim,
+ use_glu=self.use_glu,
+ name='MlpBlock_3',
+ dropout_rate=dropout_rate,
+ )(y, train, dropout_rate=dropout_rate)
      y = Dropout(dropout_rate)(y, train, rate=dropout_rate)
x = x + y
x = nn.LayerNorm(name='LayerNorm_2')(x)
@@ -131,27 +132,27 @@ class Encoder(nn.Module):
use_post_layer_norm: bool = False
@nn.compact
- def __call__(self,
- x: spec.Tensor,
- train: bool = True,
- dropout_rate: float = DROPOUT_RATE) -> spec.Tensor:
+ def __call__(
+ self, x: spec.Tensor, train: bool = True, dropout_rate: float = DROPOUT_RATE
+ ) -> spec.Tensor:
# Input Encoder
for lyr in range(self.depth):
x = Encoder1DBlock(
- name=f"encoderblock_{lyr}",
- mlp_dim=self.mlp_dim,
- num_heads=self.num_heads,
- use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm,
+ name=f'encoderblock_{lyr}',
+ mlp_dim=self.mlp_dim,
+ num_heads=self.num_heads,
+ use_glu=self.use_glu,
+ use_post_layer_norm=self.use_post_layer_norm,
)(x, train=train, dropout_rate=dropout_rate)
if not self.use_post_layer_norm:
- return nn.LayerNorm(name="encoder_layernorm")(x)
+ return nn.LayerNorm(name='encoder_layernorm')(x)
else:
return x
class MAPHead(nn.Module):
"""Multihead Attention Pooling."""
+
mlp_dim: Optional[int] = None # Defaults to 4x input dim
num_heads: int = 12
dropout_rate: float = 0.0
@@ -159,15 +160,16 @@ class MAPHead(nn.Module):
@nn.compact
def __call__(self, x, dropout_rate=DROPOUT_RATE):
n, _, d = x.shape
- probe = self.param('probe',
- nn.initializers.xavier_uniform(), (1, 1, d),
- x.dtype)
+ probe = self.param(
+ 'probe', nn.initializers.xavier_uniform(), (1, 1, d), x.dtype
+ )
probe = jnp.tile(probe, [n, 1, 1])
x = nn.MultiHeadDotProductAttention(
- num_heads=self.num_heads,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(probe, x)
+ num_heads=self.num_heads,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(probe, x)
y = nn.LayerNorm()(x)
x = x + MlpBlock(mlp_dim=self.mlp_dim, dropout_rate=dropout_rate)(y)
@@ -191,26 +193,23 @@ class ViT(nn.Module):
use_post_layer_norm: bool = False
use_map: bool = False
- def get_posemb(self,
- seqshape: tuple,
- width: int,
- dtype: jnp.dtype = jnp.float32) -> spec.Tensor:
+ def get_posemb(
+ self, seqshape: tuple, width: int, dtype: jnp.dtype = jnp.float32
+ ) -> spec.Tensor:
return posemb_sincos_2d(*seqshape, width, dtype=dtype)
@nn.compact
- def __call__(self,
- x: spec.Tensor,
- *,
- train: bool = False,
- dropout_rate=DROPOUT_RATE) -> spec.Tensor:
+ def __call__(
+ self, x: spec.Tensor, *, train: bool = False, dropout_rate=DROPOUT_RATE
+ ) -> spec.Tensor:
# Patch extraction
x = nn.Conv(
- self.width,
- self.patch_size,
- strides=self.patch_size,
- padding='VALID',
- name='conv_patch_extract')(
- x)
+ self.width,
+ self.patch_size,
+ strides=self.patch_size,
+ padding='VALID',
+ name='conv_patch_extract',
+ )(x)
n, h, w, c = x.shape
x = jnp.reshape(x, [n, h * w, c])
@@ -221,20 +220,20 @@ def __call__(self,
x = Dropout(dropout_rate)(x, not train, rate=dropout_rate)
x = Encoder(
- depth=self.depth,
- mlp_dim=self.mlp_dim,
- num_heads=self.num_heads,
- use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm,
- name='Transformer',
+ depth=self.depth,
+ mlp_dim=self.mlp_dim,
+ num_heads=self.num_heads,
+ use_glu=self.use_glu,
+ use_post_layer_norm=self.use_post_layer_norm,
+ name='Transformer',
)(x, train=not train, dropout_rate=dropout_rate)
if self.use_map:
x = MAPHead(
- num_heads=self.num_heads,
- mlp_dim=self.mlp_dim,
- dropout_rate=dropout_rate)(
- x, dropout_rate=dropout_rate)
+ num_heads=self.num_heads,
+ mlp_dim=self.mlp_dim,
+ dropout_rate=dropout_rate,
+ )(x, dropout_rate=dropout_rate)
else:
x = jnp.mean(x, axis=1)
diff --git a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
index 0e320b9b9..1637a2123 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -2,40 +2,44 @@
from typing import Dict, Optional, Tuple
+import jax
+import jax.numpy as jnp
from flax import jax_utils
from flax import linen as nn
from flax.core import pop
-import jax
-import jax.numpy as jnp
-from algoperf import param_utils
-from algoperf import spec
-from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import \
- ImagenetResNetWorkload
+from algoperf import param_utils, spec
+from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import (
+ ImagenetResNetWorkload,
+)
from algoperf.workloads.imagenet_vit.imagenet_jax import models
-from algoperf.workloads.imagenet_vit.workload import BaseImagenetVitWorkload
-from algoperf.workloads.imagenet_vit.workload import decode_variant
+from algoperf.workloads.imagenet_vit.workload import (
+ BaseImagenetVitWorkload,
+ decode_variant,
+)
# Make sure we inherit from the ViT base workload first.
class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload):
-
- def initialized(self, key: spec.RandomState,
- model: nn.Module) -> spec.ModelInitState:
+ def initialized(
+ self, key: spec.RandomState, model: nn.Module
+ ) -> spec.ModelInitState:
input_shape = (1, 224, 224, 3)
params_rng, _ = jax.random.split(key)
- variables = jax.jit(model.init)({'params': params_rng},
- jnp.ones(input_shape))
- model_state, params = pop(variables, "params")
+ variables = jax.jit(model.init)(
+ {'params': params_rng}, jnp.ones(input_shape)
+ )
+ model_state, params = pop(variables, 'params')
return params, model_state
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
self._model = models.ViT(
- num_classes=self._num_classes,
- use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm,
- use_map=self.use_map,
- **decode_variant('S/16'))
+ num_classes=self._num_classes,
+ use_glu=self.use_glu,
+ use_post_layer_norm=self.use_post_layer_norm,
+ use_map=self.use_map,
+ **decode_variant('S/16'),
+ )
params, model_state = self.initialized(rng, self._model)
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
@@ -47,49 +51,54 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'head'
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- use_running_average_bn: Optional[bool] = None,
- dropout_rate: float = models.DROPOUT_RATE
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ use_running_average_bn: Optional[bool] = None,
+ dropout_rate: float = models.DROPOUT_RATE,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
del use_running_average_bn
train = mode == spec.ForwardPassMode.TRAIN
- logits = self._model.apply({'params': params},
- augmented_and_preprocessed_input_batch['inputs'],
- rngs={'dropout': rng},
- train=train,
- dropout_rate=dropout_rate)
+ logits = self._model.apply(
+ {'params': params},
+ augmented_and_preprocessed_input_batch['inputs'],
+ rngs={'dropout': rng},
+ train=train,
+ dropout_rate=dropout_rate,
+ )
return logits, None
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
model_state = None
- return super()._eval_model_on_split(split,
- num_examples,
- global_batch_size,
- params,
- model_state,
- rng,
- data_dir,
- global_step)
+ return super()._eval_model_on_split(
+ split,
+ num_examples,
+ global_batch_size,
+ params,
+ model_state,
+ rng,
+ data_dir,
+ global_step,
+ )
class ImagenetVitGluWorkload(ImagenetVitWorkload):
-
@property
def use_glu(self) -> bool:
return True
@@ -104,7 +113,6 @@ def test_target_value(self) -> float:
class ImagenetVitPostLNWorkload(ImagenetVitWorkload):
-
@property
def use_post_layer_norm(self) -> bool:
return True
@@ -119,7 +127,6 @@ def test_target_value(self) -> float:
class ImagenetVitMapWorkload(ImagenetVitWorkload):
-
@property
def use_map(self) -> bool:
return True
diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py
index cb503cd9f..fc2a3cd46 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/models.py
@@ -9,27 +9,29 @@
from typing import Any, Optional, Tuple, Union
import torch
-from torch import nn
import torch.nn.functional as F
+from torch import nn
-from algoperf import init_utils
-from algoperf import spec
+from algoperf import init_utils, spec
from algoperf.workloads.wmt.wmt_pytorch.models import MultiheadAttention
DROPOUT_RATE = 0.0
-def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.) -> spec.Tensor:
+def posemb_sincos_2d(patches: spec.Tensor, temperature=10_000.0) -> spec.Tensor:
"""Follows the MoCo v3 logic."""
_, width, h, w = patches.shape
device = patches.device
- y, x = torch.meshgrid(torch.arange(h, device=device),
- torch.arange(w, device=device), indexing='ij')
+ y, x = torch.meshgrid(
+ torch.arange(h, device=device),
+ torch.arange(w, device=device),
+ indexing='ij',
+ )
if width % 4 != 0:
raise ValueError('Width must be mult of 4 for sincos posemb.')
omega = torch.arange(width // 4, device=device) / (width // 4 - 1)
- omega = 1. / (temperature**omega)
+ omega = 1.0 / (temperature**omega)
y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
@@ -40,10 +42,11 @@ class MlpBlock(nn.Module):
"""Transformer MLP / feed-forward block."""
def __init__(
- self,
- width: int,
- mlp_dim: Optional[int] = None, # Defaults to 4x input dim.
- use_glu: bool = False) -> None:
+ self,
+ width: int,
+ mlp_dim: Optional[int] = None, # Defaults to 4x input dim.
+ use_glu: bool = False,
+ ) -> None:
super().__init__()
self.width = width
@@ -70,7 +73,6 @@ def reset_parameters(self) -> None:
module.bias.data.normal_(std=1e-6)
def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
-
x = self.linear1(x)
x = self.act_fnc(x)
@@ -93,7 +95,8 @@ def __init__(self, width: int, num_heads: int = 8) -> None:
self.num_heads = num_heads
assert width % num_heads == 0, (
- 'Memory dimension must be divisible by number of heads.')
+ 'Memory dimension must be divisible by number of heads.'
+ )
self.head_dim = int(width / num_heads)
self.all_head_dim = self.num_heads * self.head_dim
@@ -109,7 +112,7 @@ def reset_parameters(self) -> None:
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight.data)
if module.bias is not None:
- nn.init.constant_(module.bias.data, 0.)
+ nn.init.constant_(module.bias.data, 0.0)
def transpose_for_scores(self, x: spec.Tensor) -> spec.Tensor:
new_x_shape = x.size()[:-1] + (self.num_heads, self.head_dim)
@@ -117,7 +120,6 @@ def transpose_for_scores(self, x: spec.Tensor) -> spec.Tensor:
return x.permute(0, 2, 1, 3)
def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
-
mixed_query_layer = self.query(x)
key_layer = self.transpose_for_scores(self.key(x))
@@ -141,12 +143,14 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
class Encoder1DBlock(nn.Module):
"""Single transformer encoder block (MHSA + MLP)."""
- def __init__(self,
- width: int,
- mlp_dim: Optional[int] = None,
- num_heads: int = 12,
- use_glu: bool = False,
- use_post_layer_norm: bool = False) -> None:
+ def __init__(
+ self,
+ width: int,
+ mlp_dim: Optional[int] = None,
+ num_heads: int = 12,
+ use_glu: bool = False,
+ use_post_layer_norm: bool = False,
+ ) -> None:
super().__init__()
self.width = width
@@ -159,10 +163,10 @@ def __init__(self,
self.self_attention1 = SelfAttention(self.width, self.num_heads)
self.layer_norm2 = nn.LayerNorm(self.width, eps=1e-6)
self.mlp3 = MlpBlock(
- width=self.width, mlp_dim=self.mlp_dim, use_glu=self.use_glu)
+ width=self.width, mlp_dim=self.mlp_dim, use_glu=self.use_glu
+ )
def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
-
if not self.use_post_layer_norm:
y = self.layer_norm0(x)
y = self.self_attention1(y, dropout_rate)
@@ -191,13 +195,15 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
class Encoder(nn.Module):
"""Transformer Model Encoder for sequence to sequence translation."""
- def __init__(self,
- depth: int,
- width: int,
- mlp_dim: Optional[int] = None,
- num_heads: int = 12,
- use_glu: bool = False,
- use_post_layer_norm: bool = False) -> None:
+ def __init__(
+ self,
+ depth: int,
+ width: int,
+ mlp_dim: Optional[int] = None,
+ num_heads: int = 12,
+ use_glu: bool = False,
+ use_post_layer_norm: bool = False,
+ ) -> None:
super().__init__()
self.depth = depth
@@ -207,13 +213,18 @@ def __init__(self,
self.use_glu = use_glu
self.use_post_layer_norm = use_post_layer_norm
- self.net = nn.ModuleList([
- Encoder1DBlock(self.width,
- self.mlp_dim,
- self.num_heads,
- self.use_glu,
- self.use_post_layer_norm) for _ in range(depth)
- ])
+ self.net = nn.ModuleList(
+ [
+ Encoder1DBlock(
+ self.width,
+ self.mlp_dim,
+ self.num_heads,
+ self.use_glu,
+ self.use_post_layer_norm,
+ )
+ for _ in range(depth)
+ ]
+ )
if not self.use_post_layer_norm:
self.encoder_norm = nn.LayerNorm(self.width, eps=1e-6)
@@ -233,10 +244,9 @@ def forward(self, x: spec.Tensor, dropout_rate: float) -> spec.Tensor:
class MAPHead(nn.Module):
"""Multihead Attention Pooling."""
- def __init__(self,
- width: int,
- mlp_dim: Optional[int] = None,
- num_heads: int = 12):
+ def __init__(
+ self, width: int, mlp_dim: Optional[int] = None, num_heads: int = 12
+ ):
super().__init__()
self.width = width
self.mlp_dim = mlp_dim
@@ -246,7 +256,8 @@ def __init__(self,
nn.init.xavier_uniform_(self.probe.data)
self.mha = MultiheadAttention(
- self.width, num_heads=self.num_heads, self_attn=False, bias=True)
+ self.width, num_heads=self.num_heads, self_attn=False, bias=True
+ )
self.layer_norm = nn.LayerNorm(self.width, eps=1e-6)
self.mlp = MlpBlock(width=self.width, mlp_dim=self.mlp_dim)
@@ -268,19 +279,20 @@ class ViT(nn.Module):
channels: int = 3
def __init__(
- self,
- num_classes: int = 1000,
- patch_size: Tuple[int, int] = (16, 16),
- width: int = 768,
- depth: int = 12,
- mlp_dim: Optional[int] = None, # Defaults to 4x input dim.
- num_heads: int = 12,
- rep_size: Union[int, bool] = True,
- head_zeroinit: bool = True,
- use_glu: bool = False,
- use_post_layer_norm: bool = False,
- use_map: bool = False,
- dtype: Any = torch.float32) -> None:
+ self,
+ num_classes: int = 1000,
+ patch_size: Tuple[int, int] = (16, 16),
+ width: int = 768,
+ depth: int = 12,
+ mlp_dim: Optional[int] = None, # Defaults to 4x input dim.
+ num_heads: int = 12,
+ rep_size: Union[int, bool] = True,
+ head_zeroinit: bool = True,
+ use_glu: bool = False,
+ use_post_layer_norm: bool = False,
+ use_map: bool = False,
+ dtype: Any = torch.float32,
+ ) -> None:
super().__init__()
self.num_classes = num_classes
@@ -301,19 +313,21 @@ def __init__(
self.pre_logits = nn.Linear(self.width, rep_size)
self.conv_patch_extract = nn.Conv2d(
- self.channels,
- self.width,
- self.patch_size,
- stride=self.patch_size,
- padding='valid')
+ self.channels,
+ self.width,
+ self.patch_size,
+ stride=self.patch_size,
+ padding='valid',
+ )
self.encoder = Encoder(
- depth=self.depth,
- width=self.width,
- mlp_dim=self.mlp_dim,
- num_heads=self.num_heads,
- use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm)
+ depth=self.depth,
+ width=self.width,
+ mlp_dim=self.mlp_dim,
+ num_heads=self.num_heads,
+ use_glu=self.use_glu,
+ use_post_layer_norm=self.use_post_layer_norm,
+ )
if self.num_classes:
self.head = nn.Linear(self.width, self.num_classes)
@@ -333,18 +347,17 @@ def reset_parameters(self) -> None:
if self.num_classes:
if self.head_zeroinit:
- nn.init.constant_(self.head.weight.data, 0.)
- nn.init.constant_(self.head.bias.data, 0.)
+ nn.init.constant_(self.head.weight.data, 0.0)
+ nn.init.constant_(self.head.bias.data, 0.0)
else:
init_utils.pytorch_default_init(self.head)
def get_posemb(self, x: spec.Tensor) -> spec.Tensor:
return posemb_sincos_2d(x).type(self.dtype)
- def forward(self,
- x: spec.Tensor,
- dropout_rate: float = DROPOUT_RATE) -> spec.Tensor:
-
+ def forward(
+ self, x: spec.Tensor, dropout_rate: float = DROPOUT_RATE
+ ) -> spec.Tensor:
# Patch extraction.
x = self.conv_patch_extract(x)
diff --git a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
index d43e90e80..9c6faf70b 100644
--- a/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
+++ b/algoperf/workloads/imagenet_vit/imagenet_pytorch/workload.py
@@ -6,29 +6,30 @@
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
-from algoperf import param_utils
-from algoperf import pytorch_utils
-from algoperf import spec
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import \
- ImagenetResNetWorkload
+from algoperf import param_utils, pytorch_utils, spec
+from algoperf.workloads.imagenet_resnet.imagenet_pytorch.workload import (
+ ImagenetResNetWorkload,
+)
from algoperf.workloads.imagenet_vit.imagenet_pytorch import models
-from algoperf.workloads.imagenet_vit.workload import BaseImagenetVitWorkload
-from algoperf.workloads.imagenet_vit.workload import decode_variant
+from algoperf.workloads.imagenet_vit.workload import (
+ BaseImagenetVitWorkload,
+ decode_variant,
+)
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
# Make sure we inherit from the ViT base workload first.
class ImagenetVitWorkload(BaseImagenetVitWorkload, ImagenetResNetWorkload):
-
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
model = models.ViT(
- num_classes=self._num_classes,
- use_glu=self.use_glu,
- use_post_layer_norm=self.use_post_layer_norm,
- use_map=self.use_map,
- **decode_variant('S/16'))
+ num_classes=self._num_classes,
+ use_glu=self.use_glu,
+ use_post_layer_norm=self.use_post_layer_norm,
+ use_map=self.use_map,
+ **decode_variant('S/16'),
+ )
self._param_shapes = param_utils.pytorch_param_shapes(model)
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
model.to(DEVICE)
@@ -43,14 +44,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['head.weight', 'head.bias']
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
@@ -65,20 +66,20 @@ def model_fn(
model.train()
contexts = {
- spec.ForwardPassMode.EVAL: torch.no_grad,
- spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
+ spec.ForwardPassMode.EVAL: torch.no_grad,
+ spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
logits_batch = model(
- augmented_and_preprocessed_input_batch['inputs'],
- dropout_rate=dropout_rate)
+ augmented_and_preprocessed_input_batch['inputs'],
+ dropout_rate=dropout_rate,
+ )
return logits_batch, None
class ImagenetVitGluWorkload(ImagenetVitWorkload):
-
@property
def use_glu(self) -> bool:
return True
@@ -93,7 +94,6 @@ def test_target_value(self) -> float:
class ImagenetVitPostLNWorkload(ImagenetVitWorkload):
-
@property
def use_post_layer_norm(self) -> bool:
return True
@@ -108,7 +108,6 @@ def test_target_value(self) -> float:
class ImagenetVitMapWorkload(ImagenetVitWorkload):
-
@property
def use_map(self) -> bool:
return True
diff --git a/algoperf/workloads/imagenet_vit/workload.py b/algoperf/workloads/imagenet_vit/workload.py
index f249ddee8..2a0070ba4 100644
--- a/algoperf/workloads/imagenet_vit/workload.py
+++ b/algoperf/workloads/imagenet_vit/workload.py
@@ -3,8 +3,9 @@
from typing import Dict, Iterator, Optional
from algoperf import spec
-from algoperf.workloads.imagenet_resnet.workload import \
- BaseImagenetResNetWorkload
+from algoperf.workloads.imagenet_resnet.workload import (
+ BaseImagenetResNetWorkload,
+)
def decode_variant(variant: str) -> Dict[str, int]:
@@ -12,46 +13,52 @@ def decode_variant(variant: str) -> Dict[str, int]:
v, patch = variant.split('/')
return {
- # Reference: Table 2 of https://arxiv.org/abs/2106.04560.
- 'width': {
- 'Ti': 192,
- 'S': 384,
- 'M': 512,
- 'B': 768,
- 'L': 1024,
- 'H': 1280,
- 'g': 1408,
- 'G': 1664,
- }[v],
- 'depth': {
- 'Ti': 12,
- 'S': 12,
- 'M': 12,
- 'B': 12,
- 'L': 24,
- 'H': 32,
- 'g': 40,
- 'G': 48,
- }[v],
- 'mlp_dim': {
- 'Ti': 768,
- 'S': 1536,
- 'M': 2048,
- 'B': 3072,
- 'L': 4096,
- 'H': 5120,
- 'g': 6144,
- 'G': 8192,
- }[v],
- 'num_heads': {
- 'Ti': 3, 'S': 6, 'M': 8, 'B': 12, 'L': 16, 'H': 16, 'g': 16, 'G': 16
- }[v],
- 'patch_size': (int(patch), int(patch)),
+ # Reference: Table 2 of https://arxiv.org/abs/2106.04560.
+ 'width': {
+ 'Ti': 192,
+ 'S': 384,
+ 'M': 512,
+ 'B': 768,
+ 'L': 1024,
+ 'H': 1280,
+ 'g': 1408,
+ 'G': 1664,
+ }[v],
+ 'depth': {
+ 'Ti': 12,
+ 'S': 12,
+ 'M': 12,
+ 'B': 12,
+ 'L': 24,
+ 'H': 32,
+ 'g': 40,
+ 'G': 48,
+ }[v],
+ 'mlp_dim': {
+ 'Ti': 768,
+ 'S': 1536,
+ 'M': 2048,
+ 'B': 3072,
+ 'L': 4096,
+ 'H': 5120,
+ 'g': 6144,
+ 'G': 8192,
+ }[v],
+ 'num_heads': {
+ 'Ti': 3,
+ 'S': 6,
+ 'M': 8,
+ 'B': 12,
+ 'L': 16,
+ 'H': 16,
+ 'g': 16,
+ 'G': 16,
+ }[v],
+ 'patch_size': (int(patch), int(patch)),
}
class BaseImagenetVitWorkload(BaseImagenetResNetWorkload):
-
@property
def validation_target_value(self) -> float:
return 1 - 0.22691 # 0.77309
@@ -88,25 +95,28 @@ def eval_period_time_sec(self) -> int:
return 7 * 60 # 7 mins.
def _build_dataset(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- use_mixup: bool = False,
- use_randaug: bool = False) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ use_mixup: bool = False,
+ use_randaug: bool = False,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
# We use mixup and Randaugment for ViT workloads.
use_mixup = use_randaug = split == 'train'
- return super()._build_dataset(data_rng,
- split,
- data_dir,
- global_batch_size,
- cache,
- repeat_final_dataset,
- use_mixup,
- use_randaug)
+ return super()._build_dataset(
+ data_rng,
+ split,
+ data_dir,
+ global_batch_size,
+ cache,
+ repeat_final_dataset,
+ use_mixup,
+ use_randaug,
+ )
@property
def step_hint(self) -> int:
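
Reading the variant table above, a quick sanity check of what decode_variant returns for the 'S/16' variant used by this workload (a sketch, assuming the module path of this file):

from algoperf.workloads.imagenet_vit.workload import decode_variant

cfg = decode_variant('S/16')
assert cfg['width'] == 384
assert cfg['depth'] == 12
assert cfg['mlp_dim'] == 1536
assert cfg['num_heads'] == 6
assert cfg['patch_size'] == (16, 16)
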
diff --git a/algoperf/workloads/librispeech_conformer/input_pipeline.py b/algoperf/workloads/librispeech_conformer/input_pipeline.py
index 1310e7b59..570db07b3 100644
--- a/algoperf/workloads/librispeech_conformer/input_pipeline.py
+++ b/algoperf/workloads/librispeech_conformer/input_pipeline.py
@@ -10,7 +10,6 @@
class LibriSpeechDataset(torch.utils.data.Dataset):
-
def __init__(self, split, data_dir):
super().__init__()
self.data_dir = data_dir
@@ -38,13 +37,14 @@ def __getitem__(self, index):
audio_paddings = np.zeros_like(audio, dtype=np.float32)
audio_paddings = np.pad(
- audio_paddings, (0, 320000 - audio.shape[0]), constant_values=1.0)
+ audio_paddings, (0, 320000 - audio.shape[0]), constant_values=1.0
+ )
audio = np.pad(audio, (0, 320000 - audio.shape[0]), constant_values=0.0)
target_paddings = np.zeros_like(targets, dtype=np.float32)
target_paddings = np.pad(
- target_paddings, (0, 256 - target_paddings.shape[0]),
- constant_values=1.0)
+ target_paddings, (0, 256 - target_paddings.shape[0]), constant_values=1.0
+ )
targets = np.pad(targets, (0, 256 - targets.shape[0]), constant_values=0)
audio = audio.astype(np.float32)
audio_paddings = audio_paddings.astype(np.float32)
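
__getitem__ above pads every utterance to a fixed 320000 audio samples and 256 target tokens, with 1.0 marking padded positions in the padding arrays. A self-contained sketch of that padding scheme on hypothetical data:

import numpy as np

audio = np.random.randn(123456).astype(np.float32)  # hypothetical utterance
targets = np.array([5, 17, 3], dtype=np.int32)  # hypothetical label ids

audio_paddings = np.pad(
  np.zeros_like(audio, dtype=np.float32),
  (0, 320000 - audio.shape[0]),
  constant_values=1.0,
)
audio = np.pad(audio, (0, 320000 - audio.shape[0]), constant_values=0.0)

target_paddings = np.pad(
  np.zeros_like(targets, dtype=np.float32),
  (0, 256 - targets.shape[0]),
  constant_values=1.0,
)
targets = np.pad(targets, (0, 256 - targets.shape[0]), constant_values=0)

assert audio.shape == (320000,) and targets.shape == (256,)
assert audio_paddings.sum() == 320000 - 123456
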
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py
index 9f45434d9..bd36b1bb9 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py
@@ -21,174 +21,175 @@
_MEL_HIGH_FREQUENCY_Q = 1127.0
LIBRISPEECH_MEAN_VECTOR = [
- -7.6047816276550293,
- -7.1206226348876953,
- -6.8864245414733887,
- -6.8705768585205078,
- -6.9667720794677734,
- -7.1084094047546387,
- -6.9528026580810547,
- -6.783994197845459,
- -6.6195521354675293,
- -6.4876265525817871,
- -6.4120659828186035,
- -6.394047737121582,
- -6.4244871139526367,
- -6.3993711471557617,
- -6.5158271789550781,
- -6.7137999534606934,
- -6.8476877212524414,
- -6.9885001182556152,
- -6.9221386909484863,
- -7.146148681640625,
- -7.2040400505065918,
- -7.0537552833557129,
- -7.3140382766723633,
- -7.1223249435424805,
- -7.30251407623291,
- -7.1212143898010254,
- -7.2425732612609863,
- -7.1730537414550781,
- -7.0979413986206055,
- -7.088747501373291,
- -6.9849910736083984,
- -6.8787732124328613,
- -6.7602753639221191,
- -6.6300945281982422,
- -6.5145769119262695,
- -6.4245057106018066,
- -6.356513500213623,
- -6.31787633895874,
- -6.2660770416259766,
- -6.2468328475952148,
- -6.2821526527404785,
- -6.1908388137817383,
- -6.2484354972839355,
- -6.1472640037536621,
- -6.0924725532531738,
- -6.0171003341674805,
- -5.9250402450561523,
- -5.8535833358764648,
- -5.8209109306335449,
- -5.8118929862976074,
- -5.80783748626709,
- -5.7714629173278809,
- -5.7453732490539551,
- -5.7705655097961426,
- -5.7765641212463379,
- -5.7831673622131348,
- -5.7954087257385254,
- -5.7994823455810547,
- -5.8023476600646973,
- -5.8047118186950684,
- -5.8168182373046875,
- -5.8844799995422363,
- -5.9727106094360352,
- -6.0444660186767578,
- -6.1284866333007812,
- -6.2257585525512695,
- -6.3157496452331543,
- -6.39061164855957,
- -6.4928598403930664,
- -6.5498456954956055,
- -6.6054320335388184,
- -6.6508378982543945,
- -6.66917610168457,
- -6.6726889610290527,
- -6.684234619140625,
- -6.6974577903747559,
- -6.75471830368042,
- -6.7949142456054688,
- -6.8634209632873535,
- -6.94186544418335
+ -7.6047816276550293,
+ -7.1206226348876953,
+ -6.8864245414733887,
+ -6.8705768585205078,
+ -6.9667720794677734,
+ -7.1084094047546387,
+ -6.9528026580810547,
+ -6.783994197845459,
+ -6.6195521354675293,
+ -6.4876265525817871,
+ -6.4120659828186035,
+ -6.394047737121582,
+ -6.4244871139526367,
+ -6.3993711471557617,
+ -6.5158271789550781,
+ -6.7137999534606934,
+ -6.8476877212524414,
+ -6.9885001182556152,
+ -6.9221386909484863,
+ -7.146148681640625,
+ -7.2040400505065918,
+ -7.0537552833557129,
+ -7.3140382766723633,
+ -7.1223249435424805,
+ -7.30251407623291,
+ -7.1212143898010254,
+ -7.2425732612609863,
+ -7.1730537414550781,
+ -7.0979413986206055,
+ -7.088747501373291,
+ -6.9849910736083984,
+ -6.8787732124328613,
+ -6.7602753639221191,
+ -6.6300945281982422,
+ -6.5145769119262695,
+ -6.4245057106018066,
+ -6.356513500213623,
+ -6.31787633895874,
+ -6.2660770416259766,
+ -6.2468328475952148,
+ -6.2821526527404785,
+ -6.1908388137817383,
+ -6.2484354972839355,
+ -6.1472640037536621,
+ -6.0924725532531738,
+ -6.0171003341674805,
+ -5.9250402450561523,
+ -5.8535833358764648,
+ -5.8209109306335449,
+ -5.8118929862976074,
+ -5.80783748626709,
+ -5.7714629173278809,
+ -5.7453732490539551,
+ -5.7705655097961426,
+ -5.7765641212463379,
+ -5.7831673622131348,
+ -5.7954087257385254,
+ -5.7994823455810547,
+ -5.8023476600646973,
+ -5.8047118186950684,
+ -5.8168182373046875,
+ -5.8844799995422363,
+ -5.9727106094360352,
+ -6.0444660186767578,
+ -6.1284866333007812,
+ -6.2257585525512695,
+ -6.3157496452331543,
+ -6.39061164855957,
+ -6.4928598403930664,
+ -6.5498456954956055,
+ -6.6054320335388184,
+ -6.6508378982543945,
+ -6.66917610168457,
+ -6.6726889610290527,
+ -6.684234619140625,
+ -6.6974577903747559,
+ -6.75471830368042,
+ -6.7949142456054688,
+ -6.8634209632873535,
+ -6.94186544418335,
]
LIBRISPEECH_STD_VECTOR = [
- 3.4353282451629639,
- 3.5962932109832764,
- 3.7012472152709961,
- 3.7369205951690674,
- 3.7535104751586914,
- 3.693629264831543,
- 3.6922497749328613,
- 3.7641522884368896,
- 3.8419716358184814,
- 3.8999848365783691,
- 3.9294240474700928,
- 3.9317409992218018,
- 3.9139585494995117,
- 3.9031598567962646,
- 3.8691999912261963,
- 3.8155081272125244,
- 3.7644970417022705,
- 3.7099106311798096,
- 3.6965086460113525,
- 3.6003766059875488,
- 3.5493226051330566,
- 3.5465121269226074,
- 3.45003604888916,
- 3.4712812900543213,
- 3.4084610939025879,
- 3.4408135414123535,
- 3.4104881286621094,
- 3.4217638969421387,
- 3.4312851428985596,
- 3.4199209213256836,
- 3.4305806159973145,
- 3.4382665157318115,
- 3.4580366611480713,
- 3.4817991256713867,
- 3.4958710670471191,
- 3.5036792755126953,
- 3.5047574043273926,
- 3.4988734722137451,
- 3.493056058883667,
- 3.4822943210601807,
- 3.459430456161499,
- 3.4612770080566406,
- 3.4559063911437988,
- 3.4755423069000244,
- 3.4971549510955811,
- 3.5326557159423828,
- 3.5705199241638184,
- 3.5920312404632568,
- 3.596907377243042,
- 3.5913500785827637,
- 3.5865931510925293,
- 3.5826809406280518,
- 3.5837743282318115,
- 3.5895791053771973,
- 3.5819313526153564,
- 3.5837869644165039,
- 3.5861184597015381,
- 3.5889589786529541,
- 3.592214822769165,
- 3.5939455032348633,
- 3.5856630802154541,
- 3.5884113311767578,
- 3.5921022891998291,
- 3.5870490074157715,
- 3.5806570053100586,
- 3.5731067657470703,
- 3.5617532730102539,
- 3.54980731010437,
- 3.5527374744415283,
- 3.5475366115570068,
- 3.5387849807739258,
- 3.5256178379058838,
- 3.5031836032867432,
- 3.4922726154327393,
- 3.4879646301269531,
- 3.4725594520568848,
- 3.4558389186859131,
- 3.4351828098297119,
- 3.4284293651580811,
- 3.4299170970916748
+ 3.4353282451629639,
+ 3.5962932109832764,
+ 3.7012472152709961,
+ 3.7369205951690674,
+ 3.7535104751586914,
+ 3.693629264831543,
+ 3.6922497749328613,
+ 3.7641522884368896,
+ 3.8419716358184814,
+ 3.8999848365783691,
+ 3.9294240474700928,
+ 3.9317409992218018,
+ 3.9139585494995117,
+ 3.9031598567962646,
+ 3.8691999912261963,
+ 3.8155081272125244,
+ 3.7644970417022705,
+ 3.7099106311798096,
+ 3.6965086460113525,
+ 3.6003766059875488,
+ 3.5493226051330566,
+ 3.5465121269226074,
+ 3.45003604888916,
+ 3.4712812900543213,
+ 3.4084610939025879,
+ 3.4408135414123535,
+ 3.4104881286621094,
+ 3.4217638969421387,
+ 3.4312851428985596,
+ 3.4199209213256836,
+ 3.4305806159973145,
+ 3.4382665157318115,
+ 3.4580366611480713,
+ 3.4817991256713867,
+ 3.4958710670471191,
+ 3.5036792755126953,
+ 3.5047574043273926,
+ 3.4988734722137451,
+ 3.493056058883667,
+ 3.4822943210601807,
+ 3.459430456161499,
+ 3.4612770080566406,
+ 3.4559063911437988,
+ 3.4755423069000244,
+ 3.4971549510955811,
+ 3.5326557159423828,
+ 3.5705199241638184,
+ 3.5920312404632568,
+ 3.596907377243042,
+ 3.5913500785827637,
+ 3.5865931510925293,
+ 3.5826809406280518,
+ 3.5837743282318115,
+ 3.5895791053771973,
+ 3.5819313526153564,
+ 3.5837869644165039,
+ 3.5861184597015381,
+ 3.5889589786529541,
+ 3.592214822769165,
+ 3.5939455032348633,
+ 3.5856630802154541,
+ 3.5884113311767578,
+ 3.5921022891998291,
+ 3.5870490074157715,
+ 3.5806570053100586,
+ 3.5731067657470703,
+ 3.5617532730102539,
+ 3.54980731010437,
+ 3.5527374744415283,
+ 3.5475366115570068,
+ 3.5387849807739258,
+ 3.5256178379058838,
+ 3.5031836032867432,
+ 3.4922726154327393,
+ 3.4879646301269531,
+ 3.4725594520568848,
+ 3.4558389186859131,
+ 3.4351828098297119,
+ 3.4284293651580811,
+ 3.4299170970916748,
]
@struct.dataclass
class LibrispeechPreprocessingConfig:
"""Config to hold all preprocessing options for librispeech dataset."""
+
sample_rate: float = 16000.0
frame_size_ms: float = 25.0
frame_step_ms: float = 10.0
@@ -208,8 +209,9 @@ class LibrispeechPreprocessingConfig:
def _hertz_to_mel(frequencies_hertz):
"""Convert hertz to mel."""
- return _MEL_HIGH_FREQUENCY_Q * jnp.log(1.0 + (frequencies_hertz /
- _MEL_BREAK_FREQUENCY_HERTZ))
+ return _MEL_HIGH_FREQUENCY_Q * jnp.log(
+ 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)
+ )
def _pad_end_length(num_timesteps, frame_step, frame_size):
@@ -221,11 +223,13 @@ def _pad_end_length(num_timesteps, frame_step, frame_size):
return padded_length - num_timesteps
-def frame(x,
- frame_length: int,
- frame_step: int,
- pad_end: bool = False,
- pad_value: Union[int, float] = 0.0):
+def frame(
+ x,
+ frame_length: int,
+ frame_step: int,
+ pad_end: bool = False,
+ pad_value: Union[int, float] = 0.0,
+):
"""Slides a window and extract values.
This function extracts `x[:, n:n+frame_length, :]` with sliding `n` with
@@ -251,24 +255,31 @@ def frame(x,
if pad_end:
num_extends = _pad_end_length(num_timesteps, frame_step, frame_length)
x = jnp.pad(
- x, ((0, 0), (0, num_extends), (0, 0)),
- 'constant',
- constant_values=pad_value)
+ x,
+ ((0, 0), (0, num_extends), (0, 0)),
+ 'constant',
+ constant_values=pad_value,
+ )
flat_y = jax.lax.conv_general_dilated_patches(
- x, (frame_length,), (frame_step,),
- 'VALID',
- dimension_numbers=('NTC', 'OIT', 'NTC'))
+ x,
+ (frame_length,),
+ (frame_step,),
+ 'VALID',
+ dimension_numbers=('NTC', 'OIT', 'NTC'),
+ )
ret = flat_y.reshape(flat_y.shape[:-1] + (num_channels, frame_length))
return ret.transpose((0, 1, 3, 2))
-def linear_to_mel_weight_matrix(num_mel_bins: int = 20,
- num_spectrogram_bins: int = 129,
- sample_rate: Union[int, float] = 8000,
- lower_edge_hertz: Union[int, float] = 125.0,
- upper_edge_hertz: Union[int, float] = 3800.0,
- dtype: Any = jnp.float32):
+def linear_to_mel_weight_matrix(
+ num_mel_bins: int = 20,
+ num_spectrogram_bins: int = 129,
+ sample_rate: Union[int, float] = 8000,
+ lower_edge_hertz: Union[int, float] = 125.0,
+ upper_edge_hertz: Union[int, float] = 3800.0,
+ dtype: Any = jnp.float32,
+):
r"""Jax-port of `tf.signal.linear_to_mel_weight_matrix`.
Args:
@@ -300,23 +311,29 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20,
if num_mel_bins <= 0:
raise ValueError('num_mel_bins must be positive. Got: %s' % num_mel_bins)
if lower_edge_hertz < 0.0:
- raise ValueError('lower_edge_hertz must be non-negative. Got: %s' %
- lower_edge_hertz)
+ raise ValueError(
+ 'lower_edge_hertz must be non-negative. Got: %s' % lower_edge_hertz
+ )
if lower_edge_hertz >= upper_edge_hertz:
- raise ValueError('lower_edge_hertz %.1f >= upper_edge_hertz %.1f' %
- (lower_edge_hertz, upper_edge_hertz))
+ raise ValueError(
+ 'lower_edge_hertz %.1f >= upper_edge_hertz %.1f'
+ % (lower_edge_hertz, upper_edge_hertz)
+ )
if sample_rate <= 0.0:
raise ValueError('sample_rate must be positive. Got: %s' % sample_rate)
if upper_edge_hertz > sample_rate / 2:
- raise ValueError('upper_edge_hertz must not be larger than the Nyquist '
- 'frequency (sample_rate / 2). Got %s for sample_rate: %s' %
- (upper_edge_hertz, sample_rate))
+ raise ValueError(
+ 'upper_edge_hertz must not be larger than the Nyquist '
+ 'frequency (sample_rate / 2). Got %s for sample_rate: %s'
+ % (upper_edge_hertz, sample_rate)
+ )
# HTK excludes the spectrogram DC bin.
bands_to_zero = 1
nyquist_hertz = sample_rate / 2.0
linear_frequencies = jnp.linspace(
- 0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype)[bands_to_zero:]
+ 0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype
+ )[bands_to_zero:]
spectrogram_bins_mel = _hertz_to_mel(linear_frequencies)[:, jnp.newaxis]
# Compute num_mel_bins triples of (lower_edge, center, upper_edge). The
@@ -324,10 +341,11 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20,
# Accordingly, we divide [lower_edge_hertz, upper_edge_hertz] into
# num_mel_bins + 2 pieces.
edges = jnp.linspace(
- _hertz_to_mel(lower_edge_hertz),
- _hertz_to_mel(upper_edge_hertz),
- num_mel_bins + 2,
- dtype=dtype)
+ _hertz_to_mel(lower_edge_hertz),
+ _hertz_to_mel(upper_edge_hertz),
+ num_mel_bins + 2,
+ dtype=dtype,
+ )
# Split the triples up and reshape them into [1, num_mel_bins] tensors.
lower_edge_mel = edges[:-2][jnp.newaxis, :]
@@ -337,9 +355,11 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20,
# Calculate lower and upper slopes for every spectrogram bin.
# Line segments are linear in the mel domain, not Hertz.
lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / (
- center_mel - lower_edge_mel)
+ center_mel - lower_edge_mel
+ )
upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / (
- upper_edge_mel - center_mel)
+ upper_edge_mel - center_mel
+ )
# Intersect the line segments with each other and zero.
mel_weights_matrix = jnp.maximum(0.0, jnp.minimum(lower_slopes, upper_slopes))
@@ -366,23 +386,26 @@ def _hanning_greco(win_support, frame_size, dtype):
"""
if frame_size < win_support:
raise ValueError(
- 'Provided frame_size = {} is lower than win_support = {}'.format(
- frame_size, win_support))
+ 'Provided frame_size = {} is lower than win_support = {}'.format(
+ frame_size, win_support
+ )
+ )
arg = jnp.pi * 2.0 / (win_support)
- hann = 0.5 - (0.5 * jnp.cos(arg *
- (jnp.arange(win_support, dtype=dtype) + 0.5)))
+ hann = 0.5 - (
+ 0.5 * jnp.cos(arg * (jnp.arange(win_support, dtype=dtype) + 0.5))
+ )
zero_size = frame_size - win_support
return jnp.pad(hann, [(0, zero_size)])
def _next_pow_of_two(x: Union[int, float]) -> int:
- return int(2**np.ceil(np.log2(x)))
+ return int(2 ** np.ceil(np.log2(x)))
class SpectrogramFrontend(nn.Module):
- """Layer to convert input audio signals from time domain to frequency domain.
- """
+ """Layer to convert input audio signals from time domain to frequency domain."""
+
config: LibrispeechPreprocessingConfig = None
input_scale_factor: float = 1.0
output_log: bool = False
@@ -390,8 +413,9 @@ class SpectrogramFrontend(nn.Module):
def setup(self) -> None:
p = self.config
self._frame_step = int(round(p.sample_rate * p.frame_step_ms / 1000.0))
- self._frame_size = int(round(
- p.sample_rate * p.frame_size_ms / 1000.0)) + 1 # +1 for the preemph
+ self._frame_size = (
+ int(round(p.sample_rate * p.frame_size_ms / 1000.0)) + 1
+ ) # +1 for the preemph
# TF-version has maximum of 512, but it's not always necessary
self.fft_size = _next_pow_of_two(self._frame_size)
@@ -421,32 +445,39 @@ def f(frame_size, dtype):
def _apply_preemphasis(self, framed_signal):
p = self.config
if p.preemph_htk_flavor:
- return jnp.concatenate([
- framed_signal[:, :, :1, :] * (1. - p.preemph),
- (framed_signal[:, :, 1:-1, :] -
- p.preemph * framed_signal[:, :, :-2, :])
- ],
- axis=2)
+ return jnp.concatenate(
+ [
+ framed_signal[:, :, :1, :] * (1.0 - p.preemph),
+ (
+ framed_signal[:, :, 1:-1, :]
+ - p.preemph * framed_signal[:, :, :-2, :]
+ ),
+ ],
+ axis=2,
+ )
else:
- return (framed_signal[:, :, 1:, :] -
- p.preemph * framed_signal[:, :, :-1, :])
+ return (
+ framed_signal[:, :, 1:, :] - p.preemph * framed_signal[:, :, :-1, :]
+ )
def fprop_paddings(self, input_paddings):
p = self.config
if p.pad_end:
- num_extends = _pad_end_length(input_paddings.shape[1],
- self._frame_step,
- self._frame_size)
+ num_extends = _pad_end_length(
+ input_paddings.shape[1], self._frame_step, self._frame_size
+ )
input_paddings = jnp.pad(
- input_paddings, ((0, 0), (0, num_extends)), constant_values=1.0)
+ input_paddings, ((0, 0), (0, num_extends)), constant_values=1.0
+ )
return jax.lax.reduce_window(
- input_paddings,
- init_value=1.0,
- computation=jax.lax.min,
- window_dimensions=[1, self._frame_size],
- window_strides=[1, self._frame_step],
- padding='valid')
+ input_paddings,
+ init_value=1.0,
+ computation=jax.lax.min,
+ window_dimensions=[1, self._frame_size],
+ window_strides=[1, self._frame_step],
+ padding='valid',
+ )
def next_prng_key(self, name='dropout'):
return self.make_rng(name)
@@ -469,7 +500,8 @@ def __call__(self, inputs, input_paddings):
pcm_audio_chunk = inputs.astype(jnp.float32) * self.input_scale_factor
framed_signal = frame(
- pcm_audio_chunk, self._frame_size, self._frame_step, pad_end=p.pad_end)
+ pcm_audio_chunk, self._frame_size, self._frame_step, pad_end=p.pad_end
+ )
if p.preemph != 0.0:
preemphasized = self._apply_preemphasis(framed_signal)
@@ -477,8 +509,10 @@ def __call__(self, inputs, input_paddings):
preemphasized = framed_signal[..., :-1, :]
if p.noise_scale > 0.0:
- noise_signal = jax.random.normal(self.next_prng_key(),
- preemphasized.shape) * p.noise_scale
+ noise_signal = (
+ jax.random.normal(self.next_prng_key(), preemphasized.shape)
+ * p.noise_scale
+ )
else:
noise_signal = jnp.zeros(preemphasized.shape)
@@ -501,8 +535,8 @@ def __call__(self, inputs, input_paddings):
class MelFilterbankFrontend(nn.Module):
- """Layer to compute log mel spectograms from input audio signals.
- """
+  """Layer to compute log mel spectrograms from input audio signals."""
+
config: LibrispeechPreprocessingConfig = None
use_divide_stream: bool = True
per_bin_mean: Optional[float] = None
@@ -513,7 +547,8 @@ def setup(self):
input_scale_factor = 2**-15 if self.use_divide_stream else 1.0
self.stft = SpectrogramFrontend(
- p, input_scale_factor=input_scale_factor, output_log=False)
+ p, input_scale_factor=input_scale_factor, output_log=False
+ )
if self.per_bin_mean is None:
per_bin_mean = [0.0] * p.num_bins
@@ -526,9 +561,11 @@ def setup(self):
per_bin_stddev = self.per_bin_stddev
self._normalizer_mean = jnp.array(per_bin_mean)[
- jnp.newaxis, jnp.newaxis, :, jnp.newaxis]
+ jnp.newaxis, jnp.newaxis, :, jnp.newaxis
+ ]
self._normalizer_stddev = jnp.array(per_bin_stddev)[
- jnp.newaxis, jnp.newaxis, :, jnp.newaxis]
+ jnp.newaxis, jnp.newaxis, :, jnp.newaxis
+ ]
@nn.compact
def __call__(self, inputs, input_paddings):
@@ -537,18 +574,21 @@ def __call__(self, inputs, input_paddings):
spect, spect_paddings = self.stft(inputs, input_paddings)
mel_weights = linear_to_mel_weight_matrix(
- num_mel_bins=p.num_bins,
- num_spectrogram_bins=spect.shape[2],
- sample_rate=p.sample_rate,
- lower_edge_hertz=p.lower_edge_hertz,
- upper_edge_hertz=p.upper_edge_hertz)
+ num_mel_bins=p.num_bins,
+ num_spectrogram_bins=spect.shape[2],
+ sample_rate=p.sample_rate,
+ lower_edge_hertz=p.lower_edge_hertz,
+ upper_edge_hertz=p.upper_edge_hertz,
+ )
mel_spectrogram = jnp.einsum('fn,btfc->btnc', mel_weights, spect)
logmel_spectrogram = jnp.log(jnp.maximum(mel_spectrogram, p.output_floor))
normalized_logmel_spectrogram = (
- (logmel_spectrogram - self._normalizer_mean) / self._normalizer_stddev)
+ logmel_spectrogram - self._normalizer_mean
+ ) / self._normalizer_stddev
- normalized_logmel_spectrogram = jnp.squeeze(normalized_logmel_spectrogram,
- -1)
+ normalized_logmel_spectrogram = jnp.squeeze(
+ normalized_logmel_spectrogram, -1
+ )
return normalized_logmel_spectrogram, spect_paddings
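
The frontend above builds a mel weight matrix with linear_to_mel_weight_matrix and applies it to the spectrogram via the 'fn,btfc->btnc' einsum. A hedged usage sketch with the defaults shown in this file (the batch and time sizes are made up; the expected shapes follow from that einsum):

import jax.numpy as jnp

from algoperf.workloads.librispeech_conformer.librispeech_jax import (
  librispeech_preprocessor as preprocessor,
)

mel_weights = preprocessor.linear_to_mel_weight_matrix(
  num_mel_bins=20, num_spectrogram_bins=129, sample_rate=8000
)
# Expected orientation is [num_spectrogram_bins, num_mel_bins].
assert mel_weights.shape == (129, 20)

spect = jnp.ones((2, 10, 129, 1))  # hypothetical [batch, time, freq, channel]
mel_spectrogram = jnp.einsum('fn,btfc->btnc', mel_weights, spect)
assert mel_spectrogram.shape == (2, 10, 20, 1)
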
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
index b2eee1c37..9fc2e39ef 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/models.py
@@ -16,17 +16,19 @@
import math
from typing import Any, List, Optional
-from flax import linen as nn
-from flax import struct
import jax
import jax.numpy as jnp
import numpy as np
+from flax import linen as nn
+from flax import struct
from algoperf.jax_utils import Dropout
-from algoperf.workloads.librispeech_conformer.librispeech_jax import \
- librispeech_preprocessor as preprocessor
-from algoperf.workloads.librispeech_conformer.librispeech_jax import \
- spectrum_augmenter
+from algoperf.workloads.librispeech_conformer.librispeech_jax import (
+ librispeech_preprocessor as preprocessor,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_jax import (
+ spectrum_augmenter,
+)
DROPOUT_RATE = 0.1
@@ -34,6 +36,7 @@
@struct.dataclass
class ConformerConfig:
"""Global hyperparameters used to minimize obnoxious kwarg plumbing."""
+
vocab_size: int = 1024
dtype: Any = jnp.float32
encoder_dim: int = 512
@@ -66,6 +69,7 @@ class LayerNorm(nn.Module):
zeros, this differs from default flax implementation of multiplying by scale
and initializing to ones.
"""
+
dim: int = 0
epsilon: float = 1e-6
@@ -79,7 +83,7 @@ def __call__(self, inputs):
var = jnp.mean(jnp.square(inputs - mean), axis=[-1], keepdims=True)
normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon)
- normed_inputs *= (1 + self.scale)
+ normed_inputs *= 1 + self.scale
normed_inputs += self.bias
return normed_inputs
@@ -92,6 +96,7 @@ class Subsample(nn.Module):
encoder_dim: model dimension of conformer.
input_dropout_rate: dropout rate for inputs.
"""
+
encoder_dim: int = 0
@nn.compact
@@ -100,30 +105,32 @@ def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE):
outputs = jnp.expand_dims(inputs, axis=-1)
outputs, output_paddings = Conv2dSubsampling(
- input_channels=1, output_channels=self.encoder_dim)(
- outputs, output_paddings)
+ input_channels=1, output_channels=self.encoder_dim
+ )(outputs, output_paddings)
outputs, output_paddings = Conv2dSubsampling(
- input_channels=self.encoder_dim,
- output_channels=self.encoder_dim)(outputs, output_paddings)
+ input_channels=self.encoder_dim, output_channels=self.encoder_dim
+ )(outputs, output_paddings)
batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape
outputs = jnp.reshape(
- outputs, (batch_size, subsampled_lengths, subsampled_dims * channels))
+ outputs, (batch_size, subsampled_lengths, subsampled_dims * channels)
+ )
outputs = nn.Dense(
- self.encoder_dim,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- outputs)
+ self.encoder_dim,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(outputs)
outputs = outputs + AddPositionalEmbedding(embedding_dim=self.encoder_dim)(
- seq_length=outputs.shape[1])
+ seq_length=outputs.shape[1]
+ )
- outputs = Dropout(
- rate=dropout_rate, deterministic=not train)(
- outputs, rate=dropout_rate)
+ outputs = Dropout(rate=dropout_rate, deterministic=not train)(
+ outputs, rate=dropout_rate
+ )
return outputs, output_paddings
@@ -135,6 +142,7 @@ class Conv2dSubsampling(nn.Module):
2) Also performs strided convolution over input_paddings to return the correct
paddings for downstream layers.
"""
+
input_channels: int = 0
output_channels: int = 0
filter_stride: List[int] = (2, 2)
@@ -142,24 +150,26 @@ class Conv2dSubsampling(nn.Module):
def setup(self):
self.filter_shape = (3, 3, self.input_channels, self.output_channels)
- self.kernel = self.param('kernel',
- nn.initializers.xavier_uniform(),
- self.filter_shape)
+ self.kernel = self.param(
+ 'kernel', nn.initializers.xavier_uniform(), self.filter_shape
+ )
self.bias = self.param(
- 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels)
+ 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels
+ )
@nn.compact
def __call__(self, inputs, paddings):
# Computing strided convolution to subsample inputs.
feature_group_count = inputs.shape[3] // self.filter_shape[2]
outputs = jax.lax.conv_general_dilated(
- lhs=inputs,
- rhs=self.kernel,
- window_strides=self.filter_stride,
- padding=self.padding,
- rhs_dilation=(1, 1),
- dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
- feature_group_count=feature_group_count)
+ lhs=inputs,
+ rhs=self.kernel,
+ window_strides=self.filter_stride,
+ padding=self.padding,
+ rhs_dilation=(1, 1),
+ dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
+ feature_group_count=feature_group_count,
+ )
outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,))
outputs = nn.relu(outputs)
@@ -170,58 +180,61 @@ def __call__(self, inputs, paddings):
pad_len = (input_length + stride - 1) // stride * stride - input_length
out_padding = jax.lax.conv_general_dilated(
- lhs=paddings[:, :, None],
- rhs=jnp.ones([1, 1, 1]),
- window_strides=self.filter_stride[:1],
- padding=[(0, pad_len)],
- dimension_numbers=('NHC', 'HIO', 'NHC'))
+ lhs=paddings[:, :, None],
+ rhs=jnp.ones([1, 1, 1]),
+ window_strides=self.filter_stride[:1],
+ padding=[(0, pad_len)],
+ dimension_numbers=('NHC', 'HIO', 'NHC'),
+ )
out_padding = jnp.squeeze(out_padding, axis=-1)
# Mask outputs by correct paddings to ensure padded elements in inputs map
# to padded value in outputs.
- outputs = outputs * \
- (1.0 - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1))
+ outputs = outputs * (
+ 1.0 - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1)
+ )
return outputs, out_padding
class FeedForwardModule(nn.Module):
- """Feedforward block of conformer layer.
- """
+ """Feedforward block of conformer layer."""
+
config: ConformerConfig
@nn.compact
- def __call__(self,
- inputs,
- padding_mask=None,
- train=False,
- dropout_rate=DROPOUT_RATE):
+ def __call__(
+ self, inputs, padding_mask=None, train=False, dropout_rate=DROPOUT_RATE
+ ):
config = self.config
inputs = LayerNorm(dim=config.encoder_dim)(inputs)
inputs = nn.Dense(
- config.encoder_dim * config.feed_forward_expansion_factor,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- inputs)
+ config.encoder_dim * config.feed_forward_expansion_factor,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(inputs)
if config.activation_function_name == 'swish':
activation_fn = nn.swish
elif config.activation_function_name == 'gelu':
activation_fn = nn.gelu
else:
- raise ValueError('Only "swish" and "gelu" are supported '
- 'config.activation_function_name values, recieved '
- f'{config.activation_function_name}')
+ raise ValueError(
+ 'Only "swish" and "gelu" are supported '
+        'config.activation_function_name values, received '
+ f'{config.activation_function_name}'
+ )
inputs = activation_fn(inputs)
inputs = Dropout(rate=dropout_rate)(
- inputs, deterministic=not train, rate=dropout_rate)
+ inputs, deterministic=not train, rate=dropout_rate
+ )
inputs = inputs * padding_mask
inputs = nn.Dense(
- config.encoder_dim,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- inputs)
+ config.encoder_dim,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(inputs)
inputs = inputs * padding_mask
inputs = Dropout(rate=dropout_rate)(inputs, deterministic=not train)
@@ -236,6 +249,7 @@ class AddPositionalEmbedding(nn.Module):
max_len: maximum possible length for the input
posemb_init: positional embedding initializer
"""
+
min_timescale: int = 1
max_timescale: int = 10_000
embedding_dim: int = 512
@@ -244,21 +258,23 @@ class AddPositionalEmbedding(nn.Module):
def __call__(self, seq_length):
position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :]
num_timescales = self.embedding_dim // 2
- log_timescale_increment = (
- math.log(float(self.max_timescale) / float(self.min_timescale)) /
- jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1))
+ log_timescale_increment = math.log(
+ float(self.max_timescale) / float(self.min_timescale)
+ ) / jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1)
inv_timescales = self.min_timescale * jnp.exp(
- jnp.arange(num_timescales, dtype=jnp.float32) *
- -log_timescale_increment)
+ jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment
+ )
scaled_time = (
- position[:, :, jnp.newaxis] *
- inv_timescales[jnp.newaxis, jnp.newaxis, :])
- signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)],
- axis=2).astype(jnp.float32)
+ position[:, :, jnp.newaxis] * inv_timescales[jnp.newaxis, jnp.newaxis, :]
+ )
+ signal = jnp.concatenate(
+ [jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=2
+ ).astype(jnp.float32)
# Force usage of `np` rather than `jnp` to compute static values at trace
# time.
- signal = jnp.pad(signal,
- [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]])
+ signal = jnp.pad(
+ signal, [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]]
+ )
return signal
@@ -266,6 +282,7 @@ def __call__(self, seq_length):
# https://github.com/tensorflow/lingvo/blob/7de4ca8fff3cb28c2ecb21bbd7b02a964ce727f7/lingvo/jax/layers/attentions.py#L201
class QueryScaler(nn.Module):
"""A layer to scale individual dims of the query attention matrix."""
+
dim: int = 0
def setup(self):
@@ -275,8 +292,10 @@ def setup(self):
def __call__(self, inputs):
inputs_shape = inputs.shape
if inputs_shape[-1] != self.dim:
- raise ValueError('QueryScaler expects inputs to have'
- ' same last dimension as scaling param.')
+ raise ValueError(
+ 'QueryScaler expects inputs to have'
+ ' same last dimension as scaling param.'
+ )
# 1.0/jax.nn.softplus(0.0) = 1.442695041. Hard code this number so that we
# can avoid unnecessary XLA op fusion mess on TPU.
@@ -291,18 +310,20 @@ def __call__(self, inputs):
# Modifying flax linen default dot product attention function to add
# query scaling, reference to original function here :
# https://github.com/google/flax/blob/a9af38085a7a49b571cf37d375060fd683e74972/flax/linen/attention.py#L121
-def dot_product_attention(query,
- key,
- value,
- bias=None,
- mask=None,
- broadcast_dropout=True,
- dropout_rng=None,
- dropout_rate=0.,
- deterministic=False,
- dtype=jnp.float32,
- precision=None,
- temperature=1.0):
+def dot_product_attention(
+ query,
+ key,
+ value,
+ bias=None,
+ mask=None,
+ broadcast_dropout=True,
+ dropout_rng=None,
+ dropout_rate=0.0,
+ deterministic=False,
+ dtype=jnp.float32,
+ precision=None,
+ temperature=1.0,
+):
"""Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on
@@ -341,29 +362,35 @@ def dot_product_attention(query,
"""
assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], (
- 'q, k, v batch dims must match.')
+ 'q, k, v batch dims must match.'
+ )
assert query.shape[-2] == key.shape[-2] == value.shape[-2], (
- 'q, k, v num_heads must match.')
+ 'q, k, v num_heads must match.'
+ )
assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
# compute attention weights
query = QueryScaler(dim=query.shape[-1])(query)
attn_weights = nn.attention.dot_product_attention_weights(
- query,
- key,
- bias,
- mask,
- broadcast_dropout,
- dropout_rng,
- dropout_rate,
- deterministic,
- dtype,
- precision)
+ query,
+ key,
+ bias,
+ mask,
+ broadcast_dropout,
+ dropout_rng,
+ dropout_rate,
+ deterministic,
+ dtype,
+ precision,
+ )
# return weighted sum over values for each query position
- return jnp.einsum(
- '...hqk,...khd->...qhd', attn_weights, value,
- precision=precision) * temperature
+ return (
+ jnp.einsum(
+ '...hqk,...khd->...qhd', attn_weights, value, precision=precision
+ )
+ * temperature
+ )
class MultiHeadedSelfAttention(nn.Module):
@@ -375,6 +402,7 @@ class MultiHeadedSelfAttention(nn.Module):
Note: this attention implementation uses a learned scale parameter to scale
query matrix before passing it to flax attention module.
"""
+
config: ConformerConfig = None
@nn.compact
@@ -383,28 +411,30 @@ def __call__(self, inputs, paddings, train, dropout_rate=DROPOUT_RATE):
mask_paddings = 1 - paddings
attention_mask = nn.make_attention_mask(
- mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32)
+ mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32
+ )
inputs = LayerNorm(dim=config.encoder_dim)(inputs)
attention_fn = functools.partial(
- dot_product_attention, temperature=config.attention_temperature)
+ dot_product_attention, temperature=config.attention_temperature
+ )
result = nn.MultiHeadDotProductAttention(
- num_heads=config.num_attention_heads,
- qkv_features=config.encoder_dim,
- decode=False,
- dtype=config.dtype,
- kernel_init=nn.initializers.xavier_uniform(),
- bias_init=nn.initializers.zeros,
- use_bias=True,
- broadcast_dropout=False,
- attention_fn=attention_fn,
- dropout_rate=dropout_rate,
- deterministic=not train)(
- inputs_q=inputs, mask=attention_mask)
-
- result = Dropout(
- rate=dropout_rate, deterministic=not train)(
- result, rate=dropout_rate)
+ num_heads=config.num_attention_heads,
+ qkv_features=config.encoder_dim,
+ decode=False,
+ dtype=config.dtype,
+ kernel_init=nn.initializers.xavier_uniform(),
+ bias_init=nn.initializers.zeros,
+ use_bias=True,
+ broadcast_dropout=False,
+ attention_fn=attention_fn,
+ dropout_rate=dropout_rate,
+ deterministic=not train,
+ )(inputs_q=inputs, mask=attention_mask)
+
+ result = Dropout(rate=dropout_rate, deterministic=not train)(
+ result, rate=dropout_rate
+ )
return result
@@ -421,30 +451,27 @@ class BatchNorm(nn.Module):
and the corresponding defaults for momentum and epsilon have been copied over
from lingvo.
"""
+
config: ConformerConfig
def setup(self):
dim = self.config.encoder_dim
dtype = self.config.dtype
- self.ra_mean = self.variable('batch_stats',
- 'mean',
- lambda s: jnp.zeros(s, dtype),
- dim)
- self.ra_var = self.variable('batch_stats',
- 'var',
- lambda s: jnp.ones(s, dtype),
- dim)
+ self.ra_mean = self.variable(
+ 'batch_stats', 'mean', lambda s: jnp.zeros(s, dtype), dim
+ )
+ self.ra_var = self.variable(
+ 'batch_stats', 'var', lambda s: jnp.ones(s, dtype), dim
+ )
self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype)
self.beta = self.param('bias', nn.initializers.zeros, dim, dtype)
@nn.compact
- def __call__(self,
- inputs,
- input_paddings,
- update_batch_norm,
- use_running_average_bn):
+ def __call__(
+ self, inputs, input_paddings, update_batch_norm, use_running_average_bn
+ ):
rank = inputs.ndim
reduce_over_dims = list(range(0, rank - 1))
@@ -461,23 +488,25 @@ def __call__(self,
mask = 1.0 - padding
sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True)
count_v = jnp.sum(
- jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True)
+ jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True
+ )
count_v = jnp.maximum(count_v, 1.0)
mean = sum_v / count_v
sum_vv = jnp.sum(
- (inputs - mean) * (inputs - mean) * mask,
- axis=reduce_over_dims,
- keepdims=True)
+ (inputs - mean) * (inputs - mean) * mask,
+ axis=reduce_over_dims,
+ keepdims=True,
+ )
var = sum_vv / count_v
if update_batch_norm:
- self.ra_mean.value = momentum * \
- self.ra_mean.value + (1 - momentum) * mean
- self.ra_var.value = momentum * \
- self.ra_var.value + (1 - momentum) * var
+ self.ra_mean.value = (
+ momentum * self.ra_mean.value + (1 - momentum) * mean
+ )
+ self.ra_var.value = momentum * self.ra_var.value + (1 - momentum) * var
inv = (1 + self.gamma) / jnp.sqrt(var + epsilon)
bn_output = (inputs - mean) * inv + self.beta
@@ -506,64 +535,68 @@ class ConvolutionBlock(nn.Module):
|
output
"""
+
config: ConformerConfig
@nn.compact
- def __call__(self,
- inputs,
- input_paddings,
- train,
- update_batch_norm,
- use_running_average_bn,
- dropout_rate=DROPOUT_RATE):
+ def __call__(
+ self,
+ inputs,
+ input_paddings,
+ train,
+ update_batch_norm,
+ use_running_average_bn,
+ dropout_rate=DROPOUT_RATE,
+ ):
config = self.config
inputs = LayerNorm(dim=config.encoder_dim)(inputs)
input_gated1 = nn.Dense(
- config.encoder_dim,
- kernel_init=nn.initializers.xavier_uniform(),
- use_bias=True)(
- inputs)
+ config.encoder_dim,
+ kernel_init=nn.initializers.xavier_uniform(),
+ use_bias=True,
+ )(inputs)
input_gated2 = nn.Dense(
- config.encoder_dim,
- kernel_init=nn.initializers.xavier_uniform(),
- use_bias=True)(
- inputs)
+ config.encoder_dim,
+ kernel_init=nn.initializers.xavier_uniform(),
+ use_bias=True,
+ )(inputs)
inputs = input_gated1 * jax.nn.sigmoid(input_gated2)
inputs = inputs * (1 - jnp.expand_dims(input_paddings, -1))
inputs = nn.Conv(
- features=config.encoder_dim,
- kernel_size=(config.convolution_kernel_size,),
- strides=(1,),
- padding='SAME',
- feature_group_count=config.encoder_dim,
- use_bias=False,
- kernel_init=nn.initializers.xavier_uniform())(
- inputs)
-
- inputs = BatchNorm(config)(inputs,
- input_paddings,
- update_batch_norm,
- use_running_average_bn)
+ features=config.encoder_dim,
+ kernel_size=(config.convolution_kernel_size,),
+ strides=(1,),
+ padding='SAME',
+ feature_group_count=config.encoder_dim,
+ use_bias=False,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(inputs)
+
+ inputs = BatchNorm(config)(
+ inputs, input_paddings, update_batch_norm, use_running_average_bn
+ )
if config.activation_function_name == 'swish':
activation_fn = nn.swish
elif config.activation_function_name == 'gelu':
activation_fn = nn.gelu
else:
- raise ValueError('Only "swish" and "gelu" are supported '
- 'config.activation_function_name values, recieved '
- f'{config.activation_function_name}')
+ raise ValueError(
+ 'Only "swish" and "gelu" are supported '
+        'config.activation_function_name values, received '
+ f'{config.activation_function_name}'
+ )
inputs = activation_fn(inputs)
inputs = nn.Dense(
- config.encoder_dim, kernel_init=nn.initializers.xavier_uniform())(
- inputs)
+ config.encoder_dim, kernel_init=nn.initializers.xavier_uniform()
+ )(inputs)
- inputs = Dropout(
- rate=dropout_rate, deterministic=not train)(
- inputs, rate=dropout_rate)
+ inputs = Dropout(rate=dropout_rate, deterministic=not train)(
+ inputs, rate=dropout_rate
+ )
return inputs
@@ -580,36 +613,42 @@ class ConformerBlock(nn.Module):
y = layer_norm(x)
"""
+
config: ConformerConfig
@nn.compact
- def __call__(self,
- inputs,
- input_paddings,
- train,
- update_batch_norm,
- use_running_average,
- dropout_rate=DROPOUT_RATE):
+ def __call__(
+ self,
+ inputs,
+ input_paddings,
+ train,
+ update_batch_norm,
+ use_running_average,
+ dropout_rate=DROPOUT_RATE,
+ ):
config = self.config
padding_mask = jnp.expand_dims(1 - input_paddings, -1)
inputs = inputs + 0.5 * FeedForwardModule(config=self.config)(
- inputs, padding_mask, train, dropout_rate)
+ inputs, padding_mask, train, dropout_rate
+ )
inputs = inputs + MultiHeadedSelfAttention(config=self.config)(
- inputs, input_paddings, train, dropout_rate=dropout_rate)
-
- inputs = inputs + \
- ConvolutionBlock(config)(inputs,
- input_paddings,
- train,
- update_batch_norm,
- use_running_average,
- dropout_rate
- )
+ inputs, input_paddings, train, dropout_rate=dropout_rate
+ )
+
+ inputs = inputs + ConvolutionBlock(config)(
+ inputs,
+ input_paddings,
+ train,
+ update_batch_norm,
+ use_running_average,
+ dropout_rate,
+ )
inputs = inputs + 0.5 * FeedForwardModule(config=self.config)(
- inputs, padding_mask, train, dropout_rate)
+ inputs, padding_mask, train, dropout_rate
+ )
if config.use_post_layer_norm:
inputs = LayerNorm(dim=config.encoder_dim)(inputs)
@@ -624,27 +663,30 @@ class Conformer(nn.Module):
for each time step. The output is then fed into a CTC loss which eliminates
the need for alignment with targets.
"""
+
config: ConformerConfig
def setup(self):
self.specaug = spectrum_augmenter.SpecAug(
- freq_mask_count=self.config.freq_mask_count,
- freq_mask_max_bins=self.config.freq_mask_max_bins,
- time_mask_count=self.config.time_mask_count,
- time_mask_max_frames=self.config.time_mask_max_frames,
- time_mask_max_ratio=self.config.time_mask_max_ratio,
- time_masks_per_frame=self.config.time_masks_per_frame,
- use_dynamic_time_mask_max_frames=self.config
- .use_dynamic_time_mask_max_frames)
+ freq_mask_count=self.config.freq_mask_count,
+ freq_mask_max_bins=self.config.freq_mask_max_bins,
+ time_mask_count=self.config.time_mask_count,
+ time_mask_max_frames=self.config.time_mask_max_frames,
+ time_mask_max_ratio=self.config.time_mask_max_ratio,
+ time_masks_per_frame=self.config.time_masks_per_frame,
+ use_dynamic_time_mask_max_frames=self.config.use_dynamic_time_mask_max_frames,
+ )
@nn.compact
- def __call__(self,
- inputs,
- input_paddings,
- train,
- update_batch_norm: Optional[bool] = None,
- use_running_average_bn: Optional[bool] = None,
- dropout_rate: float = DROPOUT_RATE):
+ def __call__(
+ self,
+ inputs,
+ input_paddings,
+ train,
+ update_batch_norm: Optional[bool] = None,
+ use_running_average_bn: Optional[bool] = None,
+ dropout_rate: float = DROPOUT_RATE,
+ ):
config = self.config
outputs = inputs
@@ -661,8 +703,8 @@ def __call__(self,
outputs, output_paddings = preprocessor.MelFilterbankFrontend(
preprocessing_config,
per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
- per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)(
- outputs, output_paddings)
+ per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR,
+ )(outputs, output_paddings)
# Ablate random parts of input along temporal and frequency dimension
# following the specaug procedure in https://arxiv.org/abs/1904.08779.
@@ -670,24 +712,26 @@ def __call__(self,
outputs, output_paddings = self.specaug(outputs, output_paddings)
outputs, output_paddings = Subsample(
- encoder_dim=config.encoder_dim,)(
- outputs, output_paddings, train, dropout_rate=dropout_rate)
+ encoder_dim=config.encoder_dim,
+ )(outputs, output_paddings, train, dropout_rate=dropout_rate)
# Run the conformer encoder layers.
for _ in range(config.num_encoder_layers):
- outputs = ConformerBlock(config)(outputs,
- output_paddings,
- train,
- update_batch_norm,
- use_running_average_bn,
- dropout_rate)
+ outputs = ConformerBlock(config)(
+ outputs,
+ output_paddings,
+ train,
+ update_batch_norm,
+ use_running_average_bn,
+ dropout_rate,
+ )
outputs = LayerNorm(config.encoder_dim)(outputs)
# Run the decoder which in this case is a trivial projection layer.
outputs = nn.Dense(
- config.vocab_size,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- outputs)
+ config.vocab_size,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(outputs)
return outputs, output_paddings
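
Every block in this file now calls the shared Dropout layer with the rate supplied at call time, e.g. Dropout(rate=dropout_rate, deterministic=not train)(x, rate=dropout_rate). A minimal sketch of that usage with a hypothetical module (assumes flax and the algoperf.jax_utils import above):

import flax.linen as nn
import jax
import jax.numpy as jnp

from algoperf.jax_utils import Dropout


class TinyBlock(nn.Module):
  @nn.compact
  def __call__(self, x, train, dropout_rate=0.1):
    x = nn.Dense(4)(x)
    # Rate is passed both at construction and at call time, mirroring the
    # Conformer modules above.
    return Dropout(rate=dropout_rate, deterministic=not train)(
      x, rate=dropout_rate
    )


block = TinyBlock()
x = jnp.ones((2, 8))
variables = block.init(jax.random.PRNGKey(0), x, train=False)
out = block.apply(
  variables,
  x,
  train=True,
  dropout_rate=0.2,
  rngs={'dropout': jax.random.PRNGKey(1)},
)
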
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py
index 2a6f73d4d..d9c1e301b 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/spectrum_augmenter.py
@@ -17,6 +17,7 @@ class SpecAug(nn.Module):
This is an essential component in speech recognition models that helps achieve
better word error rates.
"""
+
freq_mask_count: int = 2
freq_mask_max_bins: int = 27
time_mask_count: int = 10
@@ -28,26 +29,30 @@ class SpecAug(nn.Module):
def next_prng_key(self, name='dropout'):
return self.make_rng(name)
- def _get_mask(self,
- batch_size,
- choose_range,
- mask_size,
- max_length=None,
- masks_per_frame=0.0,
- multiplicity=1,
- max_ratio=1.0):
+ def _get_mask(
+ self,
+ batch_size,
+ choose_range,
+ mask_size,
+ max_length=None,
+ masks_per_frame=0.0,
+ multiplicity=1,
+ max_ratio=1.0,
+ ):
# Sample lengths for multiple masks.
if max_length and max_length > 0:
max_length = jnp.tile(max_length, (batch_size,))
else:
max_length = choose_range * max_ratio
masked_portion = jax.random.uniform(
- key=self.next_prng_key(),
- shape=(batch_size, multiplicity),
- minval=0.0,
- maxval=1.0)
- masked_frame_size = jnp.einsum('b,bm->bm', max_length,
- masked_portion).astype(jnp.int32)
+ key=self.next_prng_key(),
+ shape=(batch_size, multiplicity),
+ minval=0.0,
+ maxval=1.0,
+ )
+ masked_frame_size = jnp.einsum(
+ 'b,bm->bm', max_length, masked_portion
+ ).astype(jnp.int32)
# Make sure the sampled length was smaller than max_ratio * length_bound.
# Note that sampling in this way was biased
# (shorter sequence may over-masked.)
@@ -57,7 +62,8 @@ def _get_mask(self,
# Choose starting point.
random_start = jax.random.uniform(
- key=self.next_prng_key(), shape=(batch_size, multiplicity), maxval=1.0)
+ key=self.next_prng_key(), shape=(batch_size, multiplicity), maxval=1.0
+ )
start_with_in_valid_range = random_start * (choose_range - length + 1)
start = start_with_in_valid_range.astype(jnp.int32)
@@ -78,11 +84,13 @@ def _get_mask(self,
# Sum masks with appropriate multiplicity.
if masks_per_frame > 0:
multiplicity_weights = jnp.tile(
- jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0),
- [batch_size, 1])
+ jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0),
+ [batch_size, 1],
+ )
multiplicity_tensor = masks_per_frame * choose_range
- multiplicity_weights = (multiplicity_weights <
- multiplicity_tensor).astype(jnp.int32)
+ multiplicity_weights = (
+ multiplicity_weights < multiplicity_tensor
+ ).astype(jnp.int32)
pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights)
else:
pre_mask = jnp.einsum('bmt->bt', pre_mask)
@@ -98,8 +106,9 @@ def _time_mask(self, inputs, length):
max_ratio = self.time_mask_max_ratio
# If maximum mask length is zero, do nothing.
- if ((time_mask_max_frames == 0 and not use_dynamic_time_mask_max_frames) or
- max_ratio <= 0.0):
+ if (
+ time_mask_max_frames == 0 and not use_dynamic_time_mask_max_frames
+ ) or max_ratio <= 0.0:
return inputs
if multiplicity == 0:
return inputs
@@ -111,13 +120,14 @@ def _time_mask(self, inputs, length):
time_mask_max_frames = None
# Create masks in time direction and apply.
block_arrays = self._get_mask(
- batch_size,
- choose_range=length,
- mask_size=time_length,
- max_length=time_mask_max_frames,
- masks_per_frame=self.time_masks_per_frame,
- multiplicity=multiplicity,
- max_ratio=max_ratio)
+ batch_size,
+ choose_range=length,
+ mask_size=time_length,
+ max_length=time_mask_max_frames,
+ masks_per_frame=self.time_masks_per_frame,
+ multiplicity=multiplicity,
+ max_ratio=max_ratio,
+ )
outputs = jnp.einsum('bxy,bx->bxy', inputs, block_arrays)
return outputs
@@ -136,13 +146,14 @@ def _frequency_mask(self, inputs):
choose_range = jnp.tile(num_freq, (batch_size,))
# Create masks in frequency direction and apply.
block_arrays = self._get_mask(
- batch_size,
- choose_range=choose_range,
- mask_size=num_freq,
- max_length=freq_mask_max_bins,
- masks_per_frame=0.0,
- multiplicity=multiplicity,
- max_ratio=1.0)
+ batch_size,
+ choose_range=choose_range,
+ mask_size=num_freq,
+ max_length=freq_mask_max_bins,
+ masks_per_frame=0.0,
+ multiplicity=multiplicity,
+ max_ratio=1.0,
+ )
outputs = jnp.einsum('bxy,by->bxy', inputs, block_arrays)
return outputs
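
The masks built by _get_mask above are applied by broadcasting a [batch, length] 0/1 array over the features with an einsum ('bxy,bx->bxy' for time, 'bxy,by->bxy' for frequency). A standalone sketch with hand-written masks and hypothetical shapes:

import jax.numpy as jnp

inputs = jnp.ones((2, 6, 4))  # hypothetical [batch, time, freq] features
time_mask = jnp.array(
  [[1, 1, 0, 0, 1, 1], [1, 0, 0, 1, 1, 1]], dtype=jnp.float32
)  # zeros mark masked frames
freq_mask = jnp.array([[1, 1, 0, 1], [1, 1, 1, 0]], dtype=jnp.float32)

time_masked = jnp.einsum('bxy,bx->bxy', inputs, time_mask)
freq_masked = jnp.einsum('bxy,by->bxy', time_masked, freq_mask)
assert freq_masked.shape == inputs.shape
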
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
index 1e1a1d3f8..819e57a69 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/workload.py
@@ -2,31 +2,28 @@
import math
from typing import Dict, Iterator, Optional, Tuple
-from flax import jax_utils
-from flax.core import pop
import flax.linen as nn
import jax
-from jax import lax
import jax.numpy as jnp
import numpy as np
import optax
import torch
+from flax import jax_utils
+from flax.core import pop
+from jax import lax
-from algoperf import data_utils
-from algoperf import param_utils
-from algoperf import spec
-from algoperf.workloads.librispeech_conformer import metrics
-from algoperf.workloads.librispeech_conformer import workload
-from algoperf.workloads.librispeech_conformer.input_pipeline import \
- LibriSpeechDataset
+from algoperf import data_utils, param_utils, spec
+from algoperf.workloads.librispeech_conformer import metrics, workload
+from algoperf.workloads.librispeech_conformer.input_pipeline import (
+ LibriSpeechDataset,
+)
from algoperf.workloads.librispeech_conformer.librispeech_jax import models
class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload):
-
- def __init__(self,
- tokenizer_vocab_path: Optional[str] = None,
- use_specaug: bool = True) -> None:
+ def __init__(
+ self, tokenizer_vocab_path: Optional[str] = None, use_specaug: bool = True
+ ) -> None:
super().__init__()
self.metrics_bundle = metrics.get_metrics_bundle(tokenizer_vocab_path)
self.use_specaug = use_specaug
@@ -38,7 +35,8 @@ def __init__(self,
def eval_num_workers(self) -> int:
if self._eval_num_workers is None:
raise ValueError(
- 'eval_num_workers property must be set before workload is used.')
+ 'eval_num_workers property must be set before workload is used.'
+ )
return self._eval_num_workers
@eval_num_workers.setter
@@ -58,8 +56,8 @@ def attention_temperature(self) -> float:
return 1.0
def init_model_fn(
- self,
- rng: spec.RandomState,
+ self,
+ rng: spec.RandomState,
) -> spec.ModelInitState:
"""Conformer model init function.
@@ -71,10 +69,11 @@ def init_model_fn(
else:
activation_function_name = 'swish'
model_config = models.ConformerConfig(
- use_specaug=self.use_specaug,
- attention_temperature=self.attention_temperature,
- use_post_layer_norm=self.use_post_layer_norm,
- activation_function_name=activation_function_name)
+ use_specaug=self.use_specaug,
+ attention_temperature=self.attention_temperature,
+ use_post_layer_norm=self.use_post_layer_norm,
+ activation_function_name=activation_function_name,
+ )
self._model = models.Conformer(model_config)
input_shape = [(320000,), (320000,)]
@@ -85,7 +84,7 @@ def init_model_fn(
params_rng, _ = jax.random.split(rng, 2)
variables = model_init_fn({'params': params_rng}, *fake_input_batch)
- model_state, params = pop(variables, "params")
+ model_state, params = pop(variables, 'params')
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
@@ -97,49 +96,52 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_0'
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- use_running_average_bn: Optional[bool] = None,
- dropout_rate: Optional[float] = models.DROPOUT_RATE,
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ use_running_average_bn: Optional[bool] = None,
+ dropout_rate: Optional[float] = models.DROPOUT_RATE,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
variables = {'params': params, **model_state}
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
is_train_mode = mode == spec.ForwardPassMode.TRAIN
if update_batch_norm or is_train_mode:
(logits, logit_paddings), new_model_state = self._model.apply(
- variables,
- inputs,
- input_paddings,
- train=True,
- rngs={'dropout' : rng},
- mutable=['batch_stats'],
- use_running_average_bn=use_running_average_bn,
- dropout_rate=dropout_rate)
+ variables,
+ inputs,
+ input_paddings,
+ train=True,
+ rngs={'dropout': rng},
+ mutable=['batch_stats'],
+ use_running_average_bn=use_running_average_bn,
+ dropout_rate=dropout_rate,
+ )
return (logits, logit_paddings), new_model_state
else:
logits, logit_paddings = self._model.apply(
- variables,
- inputs,
- input_paddings,
- train=False,
- mutable=False,
- use_running_average_bn=use_running_average_bn)
+ variables,
+ inputs,
+ input_paddings,
+ train=False,
+ mutable=False,
+ use_running_average_bn=use_running_average_bn,
+ )
return (logits, logit_paddings), model_state
def _build_input_queue(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
del data_rng
del cache
del repeat_final_dataset
@@ -158,38 +160,41 @@ def _build_input_queue(
ds = LibriSpeechDataset(split=split, data_dir=data_dir)
dataloader = data_utils.cycle(
- torch.utils.data.DataLoader(
- ds,
- batch_size=global_batch_size,
- shuffle=train,
- sampler=None,
- num_workers=4,
- prefetch_factor=10,
- pin_memory=False,
- drop_last=train,
- ))
+ torch.utils.data.DataLoader(
+ ds,
+ batch_size=global_batch_size,
+ shuffle=train,
+ sampler=None,
+ num_workers=4,
+ prefetch_factor=10,
+ pin_memory=False,
+ drop_last=train,
+ )
+ )
for batch in iter(dataloader):
inputs, input_paddings = batch['inputs']
targets, target_paddings = batch['targets']
numpy_batch = {
- 'inputs': (inputs.numpy(), input_paddings.numpy()),
- 'targets': (targets.numpy(), target_paddings.numpy()),
+ 'inputs': (inputs.numpy(), input_paddings.numpy()),
+ 'targets': (targets.numpy(), target_paddings.numpy()),
}
padded_batch = data_utils.shard_and_maybe_pad_np(
- numpy_batch, padding_value=1.0)
+ numpy_batch, padding_value=1.0
+ )
yield padded_batch
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding)
- logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding)
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding)
+ logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding)
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -200,10 +205,9 @@ def loss_fn(
logits, logit_paddings = logits_batch
targets, target_paddings = label_batch
logprobs = nn.log_softmax(logits)
- per_example_losses = self.ctc_loss(logprobs,
- logit_paddings,
- targets,
- target_paddings)
+ per_example_losses = self.ctc_loss(
+ logprobs, logit_paddings, targets, target_paddings
+ )
# mask_batch is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -213,23 +217,26 @@ def loss_fn(
n_valid_examples = jnp.maximum(mask_batch.sum(), 1)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': n_valid_examples,
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': n_valid_examples,
+ 'per_example': per_example_losses,
}
- def ctc_loss(self,
- logits: spec.Tensor,
- logit_paddings: spec.Tensor,
- labels: spec.Tensor,
- label_paddings: spec.Tensor,
- blank_id: int = 0) -> spec.Tensor:
+ def ctc_loss(
+ self,
+ logits: spec.Tensor,
+ logit_paddings: spec.Tensor,
+ labels: spec.Tensor,
+ label_paddings: spec.Tensor,
+ blank_id: int = 0,
+ ) -> spec.Tensor:
return optax.ctc_loss(
- logits=logits,
- logit_paddings=logit_paddings,
- labels=labels,
- label_paddings=label_paddings,
- blank_id=blank_id)
+ logits=logits,
+ logit_paddings=logit_paddings,
+ labels=labels,
+ label_paddings=label_paddings,
+ blank_id=blank_id,
+ )
# Adapted from lingvo's greedy decoding logic here:
# https://github.com/tensorflow/lingvo/blob/2ee26814c57b7dcead3f0382170f2f3da006f810/lingvo/jax/layers/ctc_objectives.py#L138.
@@ -240,21 +247,22 @@ def sequence_mask(self, lengths: spec.Tensor, maxlen: int) -> spec.Tensor:
c = jnp.less_equal(b, lengths[:, jnp.newaxis]).astype(lengths.dtype)
return c
- def collapse_and_remove_blanks(self,
- labels: spec.Tensor,
- seq_length: spec.Tensor,
- blank_id: int = 0) -> spec.Tensor:
+ def collapse_and_remove_blanks(
+ self, labels: spec.Tensor, seq_length: spec.Tensor, blank_id: int = 0
+ ) -> spec.Tensor:
b, t = labels.shape
# Zap out blank.
blank_mask = 1 - jnp.equal(labels, blank_id)
labels = (labels * blank_mask).astype(labels.dtype)
# Mask labels that don't equal previous label.
- label_mask = jnp.concatenate([
+ label_mask = jnp.concatenate(
+ [
jnp.ones_like(labels[:, :1], dtype=jnp.int32),
jnp.not_equal(labels[:, 1:], labels[:, :-1]),
- ],
- axis=1)
+ ],
+ axis=1,
+ )
# Filter labels that aren't in the original sequence.
maxlen = labels.shape[1]
@@ -290,12 +298,14 @@ def collapse_and_remove_blanks(self,
# Reshape back to square batch.
batch_size = labels.shape[0]
new_shape = [batch_size, new_maxlen]
- return (jnp.reshape(flat, new_shape).astype(labels.dtype),
- new_seq_len.astype(seq_length.dtype))
+ return (
+ jnp.reshape(flat, new_shape).astype(labels.dtype),
+ new_seq_len.astype(seq_length.dtype),
+ )
def greedy_decode(
- self, logits: spec.Tensor,
- logit_paddings: spec.Tensor) -> Tuple[spec.Tensor, spec.Tensor]:
+ self, logits: spec.Tensor, logit_paddings: spec.Tensor
+ ) -> Tuple[spec.Tensor, spec.Tensor]:
per_frame_max = jnp.argmax(logits, axis=-1)
seqlen = jnp.sum(1.0 - logit_paddings, axis=-1)
hyp, _ = self.collapse_and_remove_blanks(per_frame_max, seqlen, blank_id=0)
@@ -303,45 +313,51 @@ def greedy_decode(
return hyp, hyp_paddings
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, 0, 0, 0, None),
- static_broadcasted_argnums=(0,))
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, 0, 0, 0, None),
+ static_broadcasted_argnums=(0,),
+ )
def eval_step_pmapped(
- self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
(logits, logit_paddings), _ = self.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.EVAL,
- rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.EVAL,
+ rng,
+ update_batch_norm=False,
+ )
decoded, decoded_paddings = self.greedy_decode(logits, logit_paddings)
loss = self.loss_fn(batch['targets'], (logits, logit_paddings))
targets, target_paddings = batch['targets']
return self.metrics_bundle.gather_from_model_output(
- loss_dict=loss,
- decoded=decoded,
- decoded_paddings=decoded_paddings,
- targets=targets,
- target_paddings=target_paddings,
- axis_name='batch')
-
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ loss_dict=loss,
+ decoded=decoded,
+ decoded_paddings=decoded_paddings,
+ targets=targets,
+ target_paddings=target_paddings,
+ axis_name='batch',
+ )
+
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del global_step
if model_state is not None and len(model_state) > 0:
@@ -351,15 +367,15 @@ def _eval_model_on_split(self,
num_batches = int(math.ceil(num_examples / global_batch_size))
if split not in self._eval_iters:
self._eval_iters[split] = self._build_input_queue(
- rng, split, data_dir, global_batch_size, num_batches=num_batches)
+ rng, split, data_dir, global_batch_size, num_batches=num_batches
+ )
metrics_report = None
for _ in range(num_batches):
eval_batch = next(self._eval_iters[split])
- computed_metrics = self.eval_step_pmapped(params,
- eval_batch,
- model_state,
- rng).unreplicate()
+ computed_metrics = self.eval_step_pmapped(
+ params, eval_batch, model_state, rng
+ ).unreplicate()
if metrics_report is None:
metrics_report = computed_metrics
@@ -372,7 +388,8 @@ def _eval_model_on_split(self,
return computed_metrics
def sync_batch_stats(
- self, model_state: spec.ModelAuxiliaryState) -> spec.ModelAuxiliaryState:
+ self, model_state: spec.ModelAuxiliaryState
+ ) -> spec.ModelAuxiliaryState:
# An axis_name is passed to pmap which can then be used by pmean.
# In this case each device has its own version of the batch statistics and
# we average them.
@@ -383,8 +400,8 @@ def sync_batch_stats(
class LibriSpeechConformerAttentionTemperatureWorkload(
- LibriSpeechConformerWorkload):
-
+ LibriSpeechConformerWorkload
+):
@property
def attention_temperature(self) -> float:
return 1.6
@@ -399,7 +416,6 @@ def test_target_value(self) -> float:
class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload):
-
@property
def use_post_layer_norm(self) -> bool:
return False
@@ -414,7 +430,6 @@ def test_target_value(self) -> float:
class LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload):
-
@property
def use_gelu(self) -> bool:
return True
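Illustrative sketch (editorial note, not part of the patch): the ctc_loss wrapper above forwards its arguments to optax.ctc_loss unchanged. The toy shapes and random inputs below are assumptions chosen only to show the expected argument layout (padding arrays use 0.0 for valid positions and 1.0 for padding, and id 0 is reserved for the blank symbol); per-example losses come back with shape (batch,).

import jax
import jax.numpy as jnp
import optax

B, T, K, N = 2, 50, 32, 10  # batch, frames, vocab size (0 = blank), max label length
key = jax.random.PRNGKey(0)

logits = jax.random.normal(key, (B, T, K))       # per-frame scores over the vocab
logit_paddings = jnp.zeros((B, T))               # 0.0 = valid frame, 1.0 = padding
labels = jax.random.randint(key, (B, N), 1, K)   # label ids; 0 is reserved for blank
label_paddings = jnp.zeros((B, N))               # 0.0 = valid label, 1.0 = padding

per_example_losses = optax.ctc_loss(
  logits=logits,
  logit_paddings=logit_paddings,
  labels=labels,
  label_paddings=label_paddings,
  blank_id=0,
)
print(per_example_losses.shape)  # (2,)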
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py
index 3a2eda4af..647b8ff0c 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/models.py
@@ -2,20 +2,22 @@
https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py.
"""
+import math
from dataclasses import dataclass
from functools import partial
-import math
from typing import Tuple
import torch
+import torch.nn.functional as F
from torch import nn
from torch.nn import init
-import torch.nn.functional as F
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \
- preprocessor
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \
- SpecAug
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch import (
+ preprocessor,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import (
+ SpecAug,
+)
DROPOUT_RATE = 0.1
@@ -23,6 +25,7 @@
@dataclass
class ConformerConfig:
"""Global hyperparameters used to minimize obnoxious kwarg plumbing."""
+
vocab_size: int = 1024
encoder_dim: int = 512
num_attention_heads: int = 8
@@ -58,7 +61,6 @@ def initialize(m):
class LayerNorm(nn.Module):
-
def __init__(self, dim, epsilon=1e-6):
super().__init__()
self.dim = dim
@@ -72,24 +74,25 @@ def forward(self, x):
class Subsample(nn.Module):
-
def __init__(self, encoder_dim: int = 0, num_bins: int = 80):
super().__init__()
self.encoder_dim = encoder_dim
self.conv1 = Conv2dSubsampling(
- input_channels=1, output_channels=encoder_dim)
+ input_channels=1, output_channels=encoder_dim
+ )
self.conv2 = Conv2dSubsampling(
- input_channels=encoder_dim, output_channels=encoder_dim)
+ input_channels=encoder_dim, output_channels=encoder_dim
+ )
self.linear = nn.Linear(
- in_features=self.encoder_dim * num_bins // 4,
- out_features=self.encoder_dim,
- bias=True)
+ in_features=self.encoder_dim * num_bins // 4,
+ out_features=self.encoder_dim,
+ bias=True,
+ )
self.pos_encode = AddPositionalEmbedding(embedding_dim=self.encoder_dim)
def forward(self, inputs, input_paddings, dropout_rate):
-
output_paddings = input_paddings
outputs = inputs[:, None, :, :]
@@ -97,25 +100,27 @@ def forward(self, inputs, input_paddings, dropout_rate):
outputs, output_paddings = self.conv2(outputs, output_paddings)
batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape
- outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size,
- subsampled_lengths,
- subsampled_dims * channels)
+ outputs = outputs.permute(0, 2, 3, 1).reshape(
+ batch_size, subsampled_lengths, subsampled_dims * channels
+ )
outputs = self.linear(outputs)
outputs = outputs + self.pos_encode(seq_length=outputs.shape[1])
outputs = F.dropout(
- outputs, dropout_rate, training=self.training, inplace=True)
+ outputs, dropout_rate, training=self.training, inplace=True
+ )
return outputs, output_paddings
class Conv2dSubsampling(nn.Module):
-
- def __init__(self,
- input_channels: int,
- output_channels: int,
- filter_stride: Tuple[int] = (2, 2),
- padding: str = 'SAME'):
+ def __init__(
+ self,
+ input_channels: int,
+ output_channels: int,
+ filter_stride: Tuple[int] = (2, 2),
+ padding: str = 'SAME',
+ ):
super().__init__()
self.input_channels = input_channels
@@ -126,7 +131,8 @@ def __init__(self,
self.filter_shape = (output_channels, input_channels, 3, 3)
self.kernel = nn.Parameter(
- torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape)))
+ torch.nn.init.xavier_uniform_(torch.empty(*self.filter_shape))
+ )
self.bias = nn.Parameter(torch.zeros(output_channels))
self.register_buffer('paddings_kernel', torch.ones([1, 1, 1]))
@@ -156,12 +162,13 @@ def forward(self, inputs, paddings):
else:
in_ = inputs
outputs = F.conv2d(
- input=in_,
- weight=self.kernel,
- bias=self.bias,
- stride=self.filter_stride,
- dilation=(1, 1),
- groups=groups)
+ input=in_,
+ weight=self.kernel,
+ bias=self.bias,
+ stride=self.filter_stride,
+ dilation=(1, 1),
+ groups=groups,
+ )
outputs = F.relu(outputs)
@@ -169,35 +176,37 @@ def forward(self, inputs, paddings):
stride = self.filter_stride[0]
pad_len = (input_length + stride - 1) // stride * stride - input_length
padded_paddings = F.pad(
- paddings[:, None, :], (0, pad_len), mode='constant', value=0)
+ paddings[:, None, :], (0, pad_len), mode='constant', value=0
+ )
out_padding = F.conv1d(
- input=padded_paddings,
- weight=self.paddings_kernel,
- stride=self.filter_stride[:1])
+ input=padded_paddings,
+ weight=self.paddings_kernel,
+ stride=self.filter_stride[:1],
+ )
out_padding = out_padding.squeeze(dim=1)
outputs = outputs * (1 - out_padding[:, None, :, None])
return outputs, out_padding
class FeedForwardModule(nn.Module):
-
def __init__(self, config: ConformerConfig):
super().__init__()
self.config = config
self.ln = LayerNorm(dim=config.encoder_dim)
self.linear1 = nn.Linear(
- in_features=config.encoder_dim,
- out_features=config.encoder_dim * config.feed_forward_expansion_factor,
- bias=True)
+ in_features=config.encoder_dim,
+ out_features=config.encoder_dim * config.feed_forward_expansion_factor,
+ bias=True,
+ )
self.dropout1 = nn.Dropout(p=config.feed_forward_dropout_rate, inplace=True)
self.linear2 = nn.Linear(
- in_features=config.encoder_dim * config.feed_forward_expansion_factor,
- out_features=config.encoder_dim,
- bias=True)
+ in_features=config.encoder_dim * config.feed_forward_expansion_factor,
+ out_features=config.encoder_dim,
+ bias=True,
+ )
def forward(self, inputs, padding_mask, dropout_rate):
-
inputs = self.ln(inputs)
inputs = self.linear1(inputs)
if self.config.activation_function_name == 'swish':
@@ -206,52 +215,58 @@ def forward(self, inputs, padding_mask, dropout_rate):
# Use tanh approximation of GELU which is default for jax
activation_fn = partial(F.gelu, approximate='tanh')
else:
- raise ValueError('Only "swish" and "gelu" are supported '
- 'config.activation_function_name values, recieved '
- f'{self.config.activation_function_name}')
+ raise ValueError(
+ 'Only "swish" and "gelu" are supported '
+        'config.activation_function_name values, received '
+ f'{self.config.activation_function_name}'
+ )
inputs = activation_fn(inputs)
inputs = self.dropout1(inputs)
inputs = inputs * padding_mask
inputs = self.linear2(inputs)
inputs = inputs * padding_mask
inputs = F.dropout(
- inputs, dropout_rate, training=self.training, inplace=True)
+ inputs, dropout_rate, training=self.training, inplace=True
+ )
return inputs
class AddPositionalEmbedding(nn.Module):
-
- def __init__(self,
- min_timescale: int = 1,
- max_timescale: int = 10_000,
- embedding_dim: int = 512):
+ def __init__(
+ self,
+ min_timescale: int = 1,
+ max_timescale: int = 10_000,
+ embedding_dim: int = 512,
+ ):
super().__init__()
self.min_timescale = min_timescale
self.max_timescale = max_timescale
self.embedding_dim = embedding_dim
num_timescales = self.embedding_dim // 2
log_timescale_increment = math.log(
- float(self.max_timescale) / float(self.min_timescale)) / (
- num_timescales - 1)
- inv_timescales = self.min_timescale * \
- torch.exp(torch.arange(num_timescales, dtype=torch.float32)
- * -log_timescale_increment)
+ float(self.max_timescale) / float(self.min_timescale)
+ ) / (num_timescales - 1)
+ inv_timescales = self.min_timescale * torch.exp(
+ torch.arange(num_timescales, dtype=torch.float32)
+ * -log_timescale_increment
+ )
self.register_buffer('inv_timescales', inv_timescales[None, None, :])
def forward(self, seq_length):
position = torch.arange(
- end=seq_length, dtype=torch.float32, device=self.inv_timescales.device)
+ end=seq_length, dtype=torch.float32, device=self.inv_timescales.device
+ )
scaled_time = position[None, :, None] * self.inv_timescales
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
if self.embedding_dim % 2:
signal = torch.cat(
- [signal, torch.zeros(signal.shape[0], signal.shape[1], 1)], dim=2)
+ [signal, torch.zeros(signal.shape[0], signal.shape[1], 1)], dim=2
+ )
return signal
class QueryScaler(nn.Module):
-
def __init__(self, dim):
super().__init__()
self.dim = dim
@@ -264,7 +279,6 @@ def forward(self, inputs):
class MHSAwithQS(nn.Module):
-
def __init__(self, config: ConformerConfig):
super().__init__()
self.embed_dim = config.encoder_dim
@@ -281,20 +295,23 @@ def forward(self, inputs, key_padding_mask=None):
q = self.qs(q.view(batch_size, seq_len, self.num_heads, -1)).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
- out = F.scaled_dot_product_attention(
+ out = (
+ F.scaled_dot_product_attention(
query=q,
key=k,
value=v,
attn_mask=~key_padding_mask[:, None, None],
dropout_p=self.attention_dropout_rate,
- ).transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
+ )
+ .transpose(1, 2)
+ .reshape(batch_size, seq_len, embed_dim)
+ )
out = out * self.attention_temperature
out = self.out_proj(out)
return out
class MultiHeadedSelfAttention(nn.Module):
-
def __init__(self, config: ConformerConfig):
super().__init__()
@@ -306,16 +323,16 @@ def __init__(self, config: ConformerConfig):
def forward(self, outputs, paddings, dropout_rate):
outputs = self.ln(outputs)
outputs = self.self_attention(
- outputs,
- key_padding_mask=paddings == 1,
+ outputs,
+ key_padding_mask=paddings == 1,
)
outputs = F.dropout(
- outputs, dropout_rate, training=self.training, inplace=True)
+ outputs, dropout_rate, training=self.training, inplace=True
+ )
return outputs
class BatchNorm(nn.Module):
-
def __init__(self, config: ConformerConfig):
super().__init__()
running_mean = torch.zeros(config.encoder_dim)
@@ -330,8 +347,8 @@ def __init__(self, config: ConformerConfig):
self.epsilon = config.batch_norm_epsilon
def forward(self, inputs, input_paddings):
- #inputs: NHD
- #padding: NH
+ # inputs: NHD
+ # padding: NH
"""
Alternatively:
inputs[input_paddings==0] = F.batch_norm(
@@ -355,9 +372,11 @@ def forward(self, inputs, input_paddings):
var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count
self.running_mean = (1 - self.momentum) * self.running_mean + (
- self.momentum) * mean.detach()
+ self.momentum
+ ) * mean.detach()
self.running_var = (1 - self.momentum) * self.running_var + (
- self.momentum) * var.detach()
+ self.momentum
+ ) * var.detach()
else:
mean = self.running_mean
@@ -369,25 +388,27 @@ def forward(self, inputs, input_paddings):
class ConvolutionBlock(nn.Module):
-
def __init__(self, config):
super().__init__()
self.config = config
self.ln = LayerNorm(dim=config.encoder_dim)
self.lin1 = nn.Linear(
- in_features=config.encoder_dim, out_features=config.encoder_dim)
+ in_features=config.encoder_dim, out_features=config.encoder_dim
+ )
self.lin2 = nn.Linear(
- in_features=config.encoder_dim, out_features=config.encoder_dim)
+ in_features=config.encoder_dim, out_features=config.encoder_dim
+ )
self.conv1 = nn.Conv1d(
- in_channels=config.encoder_dim,
- out_channels=config.encoder_dim,
- kernel_size=(config.convolution_kernel_size,),
- stride=(1,),
- padding='same',
- bias=False,
- groups=config.encoder_dim)
+ in_channels=config.encoder_dim,
+ out_channels=config.encoder_dim,
+ kernel_size=(config.convolution_kernel_size,),
+ stride=(1,),
+ padding='same',
+ bias=False,
+ groups=config.encoder_dim,
+ )
self.bn = BatchNorm(config)
self.lin3 = nn.Linear(config.encoder_dim, config.encoder_dim)
@@ -407,19 +428,21 @@ def forward(self, inputs, input_paddings, dropout_rate):
elif self.config.activation_function_name == 'gelu':
activation_fn = F.gelu
else:
- raise ValueError('Only "swish" and "gelu" are supported '
- 'config.activation_function_name values, recieved '
- f'{self.config.activation_function_name}')
+ raise ValueError(
+ 'Only "swish" and "gelu" are supported '
+        'config.activation_function_name values, received '
+ f'{self.config.activation_function_name}'
+ )
inputs = activation_fn(inputs)
inputs = self.lin3(inputs)
inputs = F.dropout(
- inputs, dropout_rate, training=self.training, inplace=True)
+ inputs, dropout_rate, training=self.training, inplace=True
+ )
return inputs
class ConformerBlock(nn.Module):
-
def __init__(self, config: ConformerConfig):
super().__init__()
@@ -443,28 +466,30 @@ def forward(self, inputs, input_paddings, dropout_rate):
class ConformerEncoderDecoder(nn.Module):
-
def __init__(self, config: ConformerConfig):
super().__init__()
self.config = config
preprocessing_config = preprocessor.PreprocessorConfig()
self.preprocessor = preprocessor.MelFilterbankFrontend(
- preprocessing_config,
- per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
- per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)
+ preprocessing_config,
+ per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
+ per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR,
+ )
self.specaug = SpecAug(
- freq_mask_count=config.freq_mask_count,
- freq_mask_max_bins=config.freq_mask_max_bins,
- time_mask_count=config.time_mask_count,
- time_mask_max_frames=config.time_mask_max_frames,
- time_mask_max_ratio=config.time_mask_max_ratio,
- time_masks_per_frame=config.time_masks_per_frame,
- use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames
+ freq_mask_count=config.freq_mask_count,
+ freq_mask_max_bins=config.freq_mask_max_bins,
+ time_mask_count=config.time_mask_count,
+ time_mask_max_frames=config.time_mask_max_frames,
+ time_mask_max_ratio=config.time_mask_max_ratio,
+ time_masks_per_frame=config.time_masks_per_frame,
+ use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames,
)
self.subsample = Subsample(
- encoder_dim=config.encoder_dim, num_bins=preprocessing_config.num_bins)
+ encoder_dim=config.encoder_dim, num_bins=preprocessing_config.num_bins
+ )
self.conformers = nn.ModuleList(
- [ConformerBlock(config) for _ in range(config.num_encoder_layers)])
+ [ConformerBlock(config) for _ in range(config.num_encoder_layers)]
+ )
self.ln = LayerNorm(config.encoder_dim)
self.lin = nn.Linear(config.encoder_dim, config.vocab_size)
@@ -475,8 +500,9 @@ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE):
outputs, output_paddings = self.preprocessor(outputs, output_paddings)
if self.training and self.config.use_specaug:
outputs, output_paddings = self.specaug(outputs, output_paddings)
- outputs, output_paddings = self.subsample(outputs, output_paddings,
- dropout_rate)
+ outputs, output_paddings = self.subsample(
+ outputs, output_paddings, dropout_rate
+ )
for conformer in self.conformers:
outputs = conformer(outputs, output_paddings, dropout_rate)
outputs = self.ln(outputs)
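Editorial sketch (not part of the diff) of the dropout pattern used throughout models.py above: instead of a fixed nn.Dropout module, layers call F.dropout with a dropout_rate argument received in forward(), so the caller can change the rate per call without rebuilding the model. TinyBlock and its sizes are hypothetical; only the F.dropout(x, dropout_rate, training=self.training) call mirrors the code above.

import torch
import torch.nn.functional as F
from torch import nn

DROPOUT_RATE = 0.1  # module-level default, as in models.py


class TinyBlock(nn.Module):
  def __init__(self, dim: int = 16):
    super().__init__()
    self.lin = nn.Linear(dim, dim)

  def forward(self, x, dropout_rate=DROPOUT_RATE):
    x = self.lin(x)
    # self.training still gates whether masking happens; the rate is the caller's choice.
    return F.dropout(x, dropout_rate, training=self.training)


block = TinyBlock()
x = torch.randn(2, 16)
y_default = block(x)                   # uses DROPOUT_RATE
y_custom = block(x, dropout_rate=0.3)  # caller-supplied rate, same parameters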
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py
index 558a0f796..f8c1bd0d2 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py
@@ -16,174 +16,175 @@
_MEL_HIGH_FREQUENCY_Q = 1127.0
LIBRISPEECH_MEAN_VECTOR = [
- -7.6047816276550293,
- -7.1206226348876953,
- -6.8864245414733887,
- -6.8705768585205078,
- -6.9667720794677734,
- -7.1084094047546387,
- -6.9528026580810547,
- -6.783994197845459,
- -6.6195521354675293,
- -6.4876265525817871,
- -6.4120659828186035,
- -6.394047737121582,
- -6.4244871139526367,
- -6.3993711471557617,
- -6.5158271789550781,
- -6.7137999534606934,
- -6.8476877212524414,
- -6.9885001182556152,
- -6.9221386909484863,
- -7.146148681640625,
- -7.2040400505065918,
- -7.0537552833557129,
- -7.3140382766723633,
- -7.1223249435424805,
- -7.30251407623291,
- -7.1212143898010254,
- -7.2425732612609863,
- -7.1730537414550781,
- -7.0979413986206055,
- -7.088747501373291,
- -6.9849910736083984,
- -6.8787732124328613,
- -6.7602753639221191,
- -6.6300945281982422,
- -6.5145769119262695,
- -6.4245057106018066,
- -6.356513500213623,
- -6.31787633895874,
- -6.2660770416259766,
- -6.2468328475952148,
- -6.2821526527404785,
- -6.1908388137817383,
- -6.2484354972839355,
- -6.1472640037536621,
- -6.0924725532531738,
- -6.0171003341674805,
- -5.9250402450561523,
- -5.8535833358764648,
- -5.8209109306335449,
- -5.8118929862976074,
- -5.80783748626709,
- -5.7714629173278809,
- -5.7453732490539551,
- -5.7705655097961426,
- -5.7765641212463379,
- -5.7831673622131348,
- -5.7954087257385254,
- -5.7994823455810547,
- -5.8023476600646973,
- -5.8047118186950684,
- -5.8168182373046875,
- -5.8844799995422363,
- -5.9727106094360352,
- -6.0444660186767578,
- -6.1284866333007812,
- -6.2257585525512695,
- -6.3157496452331543,
- -6.39061164855957,
- -6.4928598403930664,
- -6.5498456954956055,
- -6.6054320335388184,
- -6.6508378982543945,
- -6.66917610168457,
- -6.6726889610290527,
- -6.684234619140625,
- -6.6974577903747559,
- -6.75471830368042,
- -6.7949142456054688,
- -6.8634209632873535,
- -6.94186544418335
+ -7.6047816276550293,
+ -7.1206226348876953,
+ -6.8864245414733887,
+ -6.8705768585205078,
+ -6.9667720794677734,
+ -7.1084094047546387,
+ -6.9528026580810547,
+ -6.783994197845459,
+ -6.6195521354675293,
+ -6.4876265525817871,
+ -6.4120659828186035,
+ -6.394047737121582,
+ -6.4244871139526367,
+ -6.3993711471557617,
+ -6.5158271789550781,
+ -6.7137999534606934,
+ -6.8476877212524414,
+ -6.9885001182556152,
+ -6.9221386909484863,
+ -7.146148681640625,
+ -7.2040400505065918,
+ -7.0537552833557129,
+ -7.3140382766723633,
+ -7.1223249435424805,
+ -7.30251407623291,
+ -7.1212143898010254,
+ -7.2425732612609863,
+ -7.1730537414550781,
+ -7.0979413986206055,
+ -7.088747501373291,
+ -6.9849910736083984,
+ -6.8787732124328613,
+ -6.7602753639221191,
+ -6.6300945281982422,
+ -6.5145769119262695,
+ -6.4245057106018066,
+ -6.356513500213623,
+ -6.31787633895874,
+ -6.2660770416259766,
+ -6.2468328475952148,
+ -6.2821526527404785,
+ -6.1908388137817383,
+ -6.2484354972839355,
+ -6.1472640037536621,
+ -6.0924725532531738,
+ -6.0171003341674805,
+ -5.9250402450561523,
+ -5.8535833358764648,
+ -5.8209109306335449,
+ -5.8118929862976074,
+ -5.80783748626709,
+ -5.7714629173278809,
+ -5.7453732490539551,
+ -5.7705655097961426,
+ -5.7765641212463379,
+ -5.7831673622131348,
+ -5.7954087257385254,
+ -5.7994823455810547,
+ -5.8023476600646973,
+ -5.8047118186950684,
+ -5.8168182373046875,
+ -5.8844799995422363,
+ -5.9727106094360352,
+ -6.0444660186767578,
+ -6.1284866333007812,
+ -6.2257585525512695,
+ -6.3157496452331543,
+ -6.39061164855957,
+ -6.4928598403930664,
+ -6.5498456954956055,
+ -6.6054320335388184,
+ -6.6508378982543945,
+ -6.66917610168457,
+ -6.6726889610290527,
+ -6.684234619140625,
+ -6.6974577903747559,
+ -6.75471830368042,
+ -6.7949142456054688,
+ -6.8634209632873535,
+ -6.94186544418335,
]
LIBRISPEECH_STD_VECTOR = [
- 3.4353282451629639,
- 3.5962932109832764,
- 3.7012472152709961,
- 3.7369205951690674,
- 3.7535104751586914,
- 3.693629264831543,
- 3.6922497749328613,
- 3.7641522884368896,
- 3.8419716358184814,
- 3.8999848365783691,
- 3.9294240474700928,
- 3.9317409992218018,
- 3.9139585494995117,
- 3.9031598567962646,
- 3.8691999912261963,
- 3.8155081272125244,
- 3.7644970417022705,
- 3.7099106311798096,
- 3.6965086460113525,
- 3.6003766059875488,
- 3.5493226051330566,
- 3.5465121269226074,
- 3.45003604888916,
- 3.4712812900543213,
- 3.4084610939025879,
- 3.4408135414123535,
- 3.4104881286621094,
- 3.4217638969421387,
- 3.4312851428985596,
- 3.4199209213256836,
- 3.4305806159973145,
- 3.4382665157318115,
- 3.4580366611480713,
- 3.4817991256713867,
- 3.4958710670471191,
- 3.5036792755126953,
- 3.5047574043273926,
- 3.4988734722137451,
- 3.493056058883667,
- 3.4822943210601807,
- 3.459430456161499,
- 3.4612770080566406,
- 3.4559063911437988,
- 3.4755423069000244,
- 3.4971549510955811,
- 3.5326557159423828,
- 3.5705199241638184,
- 3.5920312404632568,
- 3.596907377243042,
- 3.5913500785827637,
- 3.5865931510925293,
- 3.5826809406280518,
- 3.5837743282318115,
- 3.5895791053771973,
- 3.5819313526153564,
- 3.5837869644165039,
- 3.5861184597015381,
- 3.5889589786529541,
- 3.592214822769165,
- 3.5939455032348633,
- 3.5856630802154541,
- 3.5884113311767578,
- 3.5921022891998291,
- 3.5870490074157715,
- 3.5806570053100586,
- 3.5731067657470703,
- 3.5617532730102539,
- 3.54980731010437,
- 3.5527374744415283,
- 3.5475366115570068,
- 3.5387849807739258,
- 3.5256178379058838,
- 3.5031836032867432,
- 3.4922726154327393,
- 3.4879646301269531,
- 3.4725594520568848,
- 3.4558389186859131,
- 3.4351828098297119,
- 3.4284293651580811,
- 3.4299170970916748
+ 3.4353282451629639,
+ 3.5962932109832764,
+ 3.7012472152709961,
+ 3.7369205951690674,
+ 3.7535104751586914,
+ 3.693629264831543,
+ 3.6922497749328613,
+ 3.7641522884368896,
+ 3.8419716358184814,
+ 3.8999848365783691,
+ 3.9294240474700928,
+ 3.9317409992218018,
+ 3.9139585494995117,
+ 3.9031598567962646,
+ 3.8691999912261963,
+ 3.8155081272125244,
+ 3.7644970417022705,
+ 3.7099106311798096,
+ 3.6965086460113525,
+ 3.6003766059875488,
+ 3.5493226051330566,
+ 3.5465121269226074,
+ 3.45003604888916,
+ 3.4712812900543213,
+ 3.4084610939025879,
+ 3.4408135414123535,
+ 3.4104881286621094,
+ 3.4217638969421387,
+ 3.4312851428985596,
+ 3.4199209213256836,
+ 3.4305806159973145,
+ 3.4382665157318115,
+ 3.4580366611480713,
+ 3.4817991256713867,
+ 3.4958710670471191,
+ 3.5036792755126953,
+ 3.5047574043273926,
+ 3.4988734722137451,
+ 3.493056058883667,
+ 3.4822943210601807,
+ 3.459430456161499,
+ 3.4612770080566406,
+ 3.4559063911437988,
+ 3.4755423069000244,
+ 3.4971549510955811,
+ 3.5326557159423828,
+ 3.5705199241638184,
+ 3.5920312404632568,
+ 3.596907377243042,
+ 3.5913500785827637,
+ 3.5865931510925293,
+ 3.5826809406280518,
+ 3.5837743282318115,
+ 3.5895791053771973,
+ 3.5819313526153564,
+ 3.5837869644165039,
+ 3.5861184597015381,
+ 3.5889589786529541,
+ 3.592214822769165,
+ 3.5939455032348633,
+ 3.5856630802154541,
+ 3.5884113311767578,
+ 3.5921022891998291,
+ 3.5870490074157715,
+ 3.5806570053100586,
+ 3.5731067657470703,
+ 3.5617532730102539,
+ 3.54980731010437,
+ 3.5527374744415283,
+ 3.5475366115570068,
+ 3.5387849807739258,
+ 3.5256178379058838,
+ 3.5031836032867432,
+ 3.4922726154327393,
+ 3.4879646301269531,
+ 3.4725594520568848,
+ 3.4558389186859131,
+ 3.4351828098297119,
+ 3.4284293651580811,
+ 3.4299170970916748,
]
@dataclass
class PreprocessorConfig:
"""Global hyperparameters used to minimize obnoxious kwarg plumbing."""
+
sample_rate = 16000
frame_size_ms = 25
frame_step_ms = 10
@@ -203,10 +204,12 @@ class PreprocessorConfig:
def _hertz_to_mel(frequencies_hertz):
"""Convert hertz to mel."""
- log_fn = math.log if type(frequencies_hertz) in [type(0.0), type(0)
- ] else torch.log
- return _MEL_HIGH_FREQUENCY_Q * log_fn(1.0 + (frequencies_hertz /
- _MEL_BREAK_FREQUENCY_HERTZ))
+ log_fn = (
+ math.log if type(frequencies_hertz) in [type(0.0), type(0)] else torch.log
+ )
+ return _MEL_HIGH_FREQUENCY_Q * log_fn(
+ 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)
+ )
def _pad_end_length(num_timesteps, frame_step, frame_size):
@@ -218,28 +221,30 @@ def _pad_end_length(num_timesteps, frame_step, frame_size):
return padded_length - num_timesteps
-def frame(x,
- frame_length: int,
- frame_step: int,
- pad_end: bool = False,
- pad_value: Union[int, float] = 0.0):
+def frame(
+ x,
+ frame_length: int,
+ frame_step: int,
+ pad_end: bool = False,
+ pad_value: Union[int, float] = 0.0,
+):
"""Slides a window and extract values.
- This function extracts `x[:, n:n+frame_length, :]` with sliding `n` with
- stride of `frame_step`, and returns an array `y` with the shape
- `(batch_size, num_frames, frame_length, num_channels)`. Unlike the
- counterpart in Tensorflow (`tf.signal.frame`), this function currently
- does not take `axis` argument, and the input tensor `x` is expected to
- have a shape of `(batch_size, timesteps, channels)`.
- Args:
- x: An input array with `(batch_size, timesteps, channels)`-shape.
- frame_length: The frame length.
- frame_step: The frame hop size.
- pad_end: If True, the end of signal is padded so the window can continue
- sliding while the starting point of the window is in the valid range.
- pad_value: A scalar used as a padding value when `pad_end` is True.
- Returns:
- A tensor with shape `(*, num_frames, frame_length, num_channels)`.
- """
+  This function extracts `x[:, n:n+frame_length, :]` by sliding `n` with a
+  stride of `frame_step`, and returns an array `y` with the shape
+  `(batch_size, num_frames, frame_length, num_channels)`. Unlike its
+  counterpart in TensorFlow (`tf.signal.frame`), this function currently
+  does not take an `axis` argument, and the input tensor `x` is expected to
+  have a shape of `(batch_size, timesteps, channels)`.
+ Args:
+ x: An input array with `(batch_size, timesteps, channels)`-shape.
+ frame_length: The frame length.
+ frame_step: The frame hop size.
+ pad_end: If True, the end of signal is padded so the window can continue
+ sliding while the starting point of the window is in the valid range.
+ pad_value: A scalar used as a padding value when `pad_end` is True.
+ Returns:
+ A tensor with shape `(*, num_frames, frame_length, num_channels)`.
+ """
num_timesteps = x.shape[1]
if pad_end:
@@ -250,60 +255,67 @@ def frame(x,
return x.permute(0, 1, 3, 2)
-def linear_to_mel_weight_matrix(num_mel_bins: int = 20,
- num_spectrogram_bins: int = 129,
- sample_rate: Union[int, float] = 8000,
- lower_edge_hertz: Union[int, float] = 125.0,
- upper_edge_hertz: Union[int, float] = 3800.0,
- dtype: Any = torch.float32,
- device='cpu'):
+def linear_to_mel_weight_matrix(
+ num_mel_bins: int = 20,
+ num_spectrogram_bins: int = 129,
+ sample_rate: Union[int, float] = 8000,
+ lower_edge_hertz: Union[int, float] = 125.0,
+ upper_edge_hertz: Union[int, float] = 3800.0,
+ dtype: Any = torch.float32,
+ device='cpu',
+):
r"""Pytorch-port of `tf.signal.linear_to_mel_weight_matrix`.
- Args:
- num_mel_bins: Python int. How many bands in the resulting mel spectrum.
- num_spectrogram_bins: An integer `Tensor`. How many bins there are in
- the source spectrogram data, which is understood to be `fft_size // 2 + 1`,
- i.e. the spectrogram only contains the nonredundant FFT bins.
- sample_rate: An integer or float `Tensor`. Samples per second of the
- input signal used to create the spectrogram. Used to figure out the
- frequencies corresponding to each spectrogram bin, which dictates how they
- are mapped into the mel scale.
- lower_edge_hertz: Python float. Lower bound on the frequencies to be
- included in the mel spectrum. This corresponds to the lower edge of the
- lowest triangular band.
- upper_edge_hertz: Python float. The desired top edge of the highest
- frequency band.
- dtype: The `DType` of the result matrix. Must be a floating point type.
- Returns:
- An array of shape `[num_spectrogram_bins, num_mel_bins]`.
- Raises:
- ValueError: If `num_mel_bins`/`num_spectrogram_bins`/`sample_rate` are not
- positive, `lower_edge_hertz` is negative, frequency edges are incorrectly
- ordered, `upper_edge_hertz` is larger than the Nyquist frequency.
- [mel]: https://en.wikipedia.org/wiki/Mel_scale
- """
+ Args:
+ num_mel_bins: Python int. How many bands in the resulting mel spectrum.
+ num_spectrogram_bins: An integer `Tensor`. How many bins there are in
+ the source spectrogram data, which is understood to be `fft_size // 2 + 1`,
+ i.e. the spectrogram only contains the nonredundant FFT bins.
+ sample_rate: An integer or float `Tensor`. Samples per second of the
+ input signal used to create the spectrogram. Used to figure out the
+ frequencies corresponding to each spectrogram bin, which dictates how they
+ are mapped into the mel scale.
+ lower_edge_hertz: Python float. Lower bound on the frequencies to be
+ included in the mel spectrum. This corresponds to the lower edge of the
+ lowest triangular band.
+ upper_edge_hertz: Python float. The desired top edge of the highest
+ frequency band.
+ dtype: The `DType` of the result matrix. Must be a floating point type.
+ Returns:
+ An array of shape `[num_spectrogram_bins, num_mel_bins]`.
+ Raises:
+ ValueError: If `num_mel_bins`/`num_spectrogram_bins`/`sample_rate` are not
+ positive, `lower_edge_hertz` is negative, frequency edges are incorrectly
+ ordered, `upper_edge_hertz` is larger than the Nyquist frequency.
+ [mel]: https://en.wikipedia.org/wiki/Mel_scale
+ """
# Input validator from tensorflow/python/ops/signal/mel_ops.py#L71
if num_mel_bins <= 0:
raise ValueError('num_mel_bins must be positive. Got: %s' % num_mel_bins)
if lower_edge_hertz < 0.0:
- raise ValueError('lower_edge_hertz must be non-negative. Got: %s' %
- lower_edge_hertz)
+ raise ValueError(
+ 'lower_edge_hertz must be non-negative. Got: %s' % lower_edge_hertz
+ )
if lower_edge_hertz >= upper_edge_hertz:
- raise ValueError('lower_edge_hertz %.1f >= upper_edge_hertz %.1f' %
- (lower_edge_hertz, upper_edge_hertz))
+ raise ValueError(
+ 'lower_edge_hertz %.1f >= upper_edge_hertz %.1f'
+ % (lower_edge_hertz, upper_edge_hertz)
+ )
if sample_rate <= 0.0:
raise ValueError('sample_rate must be positive. Got: %s' % sample_rate)
if upper_edge_hertz > sample_rate / 2:
- raise ValueError('upper_edge_hertz must not be larger than the Nyquist '
- 'frequency (sample_rate / 2). Got %s for sample_rate: %s' %
- (upper_edge_hertz, sample_rate))
+ raise ValueError(
+ 'upper_edge_hertz must not be larger than the Nyquist '
+ 'frequency (sample_rate / 2). Got %s for sample_rate: %s'
+ % (upper_edge_hertz, sample_rate)
+ )
# HTK excludes the spectrogram DC bin.
bands_to_zero = 1
nyquist_hertz = sample_rate / 2.0
linear_frequencies = torch.linspace(
- 0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype,
- device=device)[bands_to_zero:]
+ 0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype, device=device
+ )[bands_to_zero:]
spectrogram_bins_mel = _hertz_to_mel(linear_frequencies)[:, None]
# Compute num_mel_bins triples of (lower_edge, center, upper_edge). The
@@ -311,11 +323,12 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20,
# Accordingly, we divide [lower_edge_hertz, upper_edge_hertz] into
# num_mel_bins + 2 pieces.
edges = torch.linspace(
- _hertz_to_mel(lower_edge_hertz),
- _hertz_to_mel(upper_edge_hertz),
- num_mel_bins + 2,
- dtype=dtype,
- device=device)
+ _hertz_to_mel(lower_edge_hertz),
+ _hertz_to_mel(upper_edge_hertz),
+ num_mel_bins + 2,
+ dtype=dtype,
+ device=device,
+ )
# Split the triples up and reshape them into [1, num_mel_bins] tensors.
lower_edge_mel = edges[:-2][None, :]
@@ -325,13 +338,16 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20,
# Calculate lower and upper slopes for every spectrogram bin.
# Line segments are linear in the mel domain, not Hertz.
lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / (
- center_mel - lower_edge_mel)
+ center_mel - lower_edge_mel
+ )
upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / (
- upper_edge_mel - center_mel)
+ upper_edge_mel - center_mel
+ )
# Intersect the line segments with each other and zero.
mel_weights_matrix = torch.minimum(lower_slopes, upper_slopes).clamp(
- min=0.0, max=None)
+ min=0.0, max=None
+ )
# Re-add the zeroed lower bins we sliced out above.
return F.pad(mel_weights_matrix, (0, 0, bands_to_zero, 0))
@@ -339,43 +355,50 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20,
def _hanning_greco(win_support, frame_size, dtype, device='cpu'):
"""Add a greco-style hanning window to the graph.
- Note that the Hanning window in Wikipedia is not the same as the Hanning
- window in Greco. The Greco3 Hanning window at 0 is NOT 0, as the wikipedia
- page would indicate. Talkin's explanation was that it was like wasting two
- samples to have the values at the edge of the window to be 0.0 exactly.
- Args:
- win_support: Number of samples for non-zero support in the window
- frame_size: Total size of the window (frame_size >= win_support)
- dtype: TF data type
- Returns:
- Tensor of size frame_size with the window to apply.
- """
+ Note that the Hanning window in Wikipedia is not the same as the Hanning
+ window in Greco. The Greco3 Hanning window at 0 is NOT 0, as the wikipedia
+ page would indicate. Talkin's explanation was that it was like wasting two
+ samples to have the values at the edge of the window to be 0.0 exactly.
+ Args:
+ win_support: Number of samples for non-zero support in the window
+ frame_size: Total size of the window (frame_size >= win_support)
+ dtype: TF data type
+ Returns:
+ Tensor of size frame_size with the window to apply.
+ """
if frame_size < win_support:
raise ValueError(
- 'Provided frame_size = {} is lower than win_support = {}'.format(
- frame_size, win_support))
+ 'Provided frame_size = {} is lower than win_support = {}'.format(
+ frame_size, win_support
+ )
+ )
arg = torch.pi * 2.0 / (win_support)
- hann = 0.5 - (0.5 * torch.cos(
- arg * (torch.arange(win_support, dtype=dtype, device=device) + 0.5)))
+ hann = 0.5 - (
+ 0.5
+ * torch.cos(
+ arg * (torch.arange(win_support, dtype=dtype, device=device) + 0.5)
+ )
+ )
zero_size = frame_size - win_support
return F.pad(hann, (0, zero_size))
def _next_pow_of_two(x: Union[int, float]) -> int:
- return int(2**np.ceil(np.log2(x)))
+ return int(2 ** np.ceil(np.log2(x)))
class SpectrogramFrontend(nn.Module):
- """Layer to convert input audio signals from time domain to frequency domain.
- """
-
- def __init__(self,
- config: PreprocessorConfig = None,
- input_scale_factor: float = 1.0,
- output_log: bool = False,
- dtype=torch.float32,
- device='cpu'):
+ """Layer to convert input audio signals from time domain to frequency domain."""
+
+ def __init__(
+ self,
+ config: PreprocessorConfig = None,
+ input_scale_factor: float = 1.0,
+ output_log: bool = False,
+ dtype=torch.float32,
+ device='cpu',
+ ):
super().__init__()
self.config = config
@@ -384,8 +407,9 @@ def __init__(self,
p = self.config
self._frame_step = int(round(p.sample_rate * p.frame_step_ms / 1000.0))
- self._frame_size = int(round(
- p.sample_rate * p.frame_size_ms / 1000.0)) + 1 # +1 for the preemph
+ self._frame_size = (
+ int(round(p.sample_rate * p.frame_size_ms / 1000.0)) + 1
+ ) # +1 for the preemph
# TF-version has maximum of 512, but it's not always necessary
self.fft_size = _next_pow_of_two(self._frame_size)
@@ -399,23 +423,20 @@ def _hanning_window(frame_size, dtype):
if frame_size % 2 == 0:
# simulate periodic=True in tf.signal.hann_window
return torch.hann_window(
- window_length=frame_size,
- periodic=True,
- dtype=dtype,
- device=device)
+ window_length=frame_size, periodic=True, dtype=dtype, device=device
+ )
else:
return torch.hann_window(
- window_length=frame_size,
- periodic=False,
- dtype=dtype,
- device=device)
+ window_length=frame_size, periodic=False, dtype=dtype, device=device
+ )
self._window_fn = _hanning_window
elif p.window_fn.upper() == 'HANNING_GRECO':
# Greco-compatible hanning window
def f(frame_size, dtype):
return _hanning_greco(
- self._frame_size - 1, frame_size, dtype, device=device)
+ self._frame_size - 1, frame_size, dtype, device=device
+ )
self._window_fn = f
else:
@@ -430,25 +451,31 @@ def f(frame_size, dtype):
def _apply_preemphasis(self, framed_signal):
p = self.config
if p.preemph_htk_flavor:
- return torch.cat([
- framed_signal[:, :, :1, :] * (1. - p.preemph),
- (framed_signal[:, :, 1:-1, :] -
- p.preemph * framed_signal[:, :, :-2, :])
- ],
- dim=2)
+ return torch.cat(
+ [
+ framed_signal[:, :, :1, :] * (1.0 - p.preemph),
+ (
+ framed_signal[:, :, 1:-1, :]
+ - p.preemph * framed_signal[:, :, :-2, :]
+ ),
+ ],
+ dim=2,
+ )
else:
- return (framed_signal[:, :, 1:, :] -
- p.preemph * framed_signal[:, :, :-1, :])
+ return (
+ framed_signal[:, :, 1:, :] - p.preemph * framed_signal[:, :, :-1, :]
+ )
def fprop_paddings(self, input_paddings):
p = self.config
if p.pad_end:
- num_extends = _pad_end_length(input_paddings.shape[1],
- self._frame_step,
- self._frame_size)
+ num_extends = _pad_end_length(
+ input_paddings.shape[1], self._frame_step, self._frame_size
+ )
input_paddings = F.pad(input_paddings, (0, num_extends), value=1.0)
x = input_paddings.unfold(
- dimension=1, size=self._frame_size, step=self._frame_step)
+ dimension=1, size=self._frame_size, step=self._frame_step
+ )
return x.min(dim=2)[0]
def forward(self, inputs, input_paddings):
@@ -467,7 +494,8 @@ def forward(self, inputs, input_paddings):
pcm_audio_chunk = inputs * self.input_scale_factor
framed_signal = frame(
- pcm_audio_chunk, self._frame_size, self._frame_step, pad_end=p.pad_end)
+ pcm_audio_chunk, self._frame_size, self._frame_step, pad_end=p.pad_end
+ )
if p.preemph != 0.0:
preemphasized = self._apply_preemphasis(framed_signal)
@@ -497,12 +525,14 @@ def forward(self, inputs, input_paddings):
class MelFilterbankFrontend(nn.Module):
"""Layer to compute log mel spectograms from input audio signals."""
- def __init__(self,
- config: PreprocessorConfig = None,
- use_divide_stream: bool = True,
- per_bin_mean: Optional[float] = None,
- per_bin_stddev: Optional[float] = None,
- device='cpu'):
+ def __init__(
+ self,
+ config: PreprocessorConfig = None,
+ use_divide_stream: bool = True,
+ per_bin_mean: Optional[float] = None,
+ per_bin_stddev: Optional[float] = None,
+ device='cpu',
+ ):
super().__init__()
self.config = config
@@ -513,7 +543,8 @@ def __init__(self,
input_scale_factor = 2**-15 if self.use_divide_stream else 1.0
self.stft = SpectrogramFrontend(
- p, input_scale_factor=input_scale_factor, output_log=False)
+ p, input_scale_factor=input_scale_factor, output_log=False
+ )
if self.per_bin_mean is None:
per_bin_mean = [0.0] * p.num_bins
@@ -525,10 +556,13 @@ def __init__(self,
else:
per_bin_stddev = self.per_bin_stddev
- self.register_buffer('_normalizer_mean',
- torch.FloatTensor(per_bin_mean)[None, None, :, None])
- self.register_buffer('_normalizer_stddev',
- torch.FloatTensor(per_bin_stddev)[None, None, :, None])
+ self.register_buffer(
+ '_normalizer_mean', torch.FloatTensor(per_bin_mean)[None, None, :, None]
+ )
+ self.register_buffer(
+ '_normalizer_stddev',
+ torch.FloatTensor(per_bin_stddev)[None, None, :, None],
+ )
def forward(self, inputs, input_paddings):
p = self.config
@@ -536,20 +570,24 @@ def forward(self, inputs, input_paddings):
spect, spect_paddings = self.stft(inputs, input_paddings)
mel_weights = linear_to_mel_weight_matrix(
- num_mel_bins=p.num_bins,
- num_spectrogram_bins=spect.shape[2],
- sample_rate=p.sample_rate,
- lower_edge_hertz=p.lower_edge_hertz,
- upper_edge_hertz=p.upper_edge_hertz,
- device=spect.device)
+ num_mel_bins=p.num_bins,
+ num_spectrogram_bins=spect.shape[2],
+ sample_rate=p.sample_rate,
+ lower_edge_hertz=p.lower_edge_hertz,
+ upper_edge_hertz=p.upper_edge_hertz,
+ device=spect.device,
+ )
mel_spectrogram = torch.einsum('fn,btfc->btnc', mel_weights, spect)
logmel_spectrogram = torch.log(
- mel_spectrogram.clamp(min=p.output_floor, max=None))
+ mel_spectrogram.clamp(min=p.output_floor, max=None)
+ )
normalized_logmel_spectrogram = (
- (logmel_spectrogram - self._normalizer_mean) / self._normalizer_stddev)
+ logmel_spectrogram - self._normalizer_mean
+ ) / self._normalizer_stddev
- normalized_logmel_spectrogram = torch.squeeze(normalized_logmel_spectrogram,
- -1)
+ normalized_logmel_spectrogram = torch.squeeze(
+ normalized_logmel_spectrogram, -1
+ )
return normalized_logmel_spectrogram, spect_paddings
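Editorial sketch (not part of the diff) of how MelFilterbankFrontend.forward uses linear_to_mel_weight_matrix: the weight matrix has shape (num_spectrogram_bins, num_mel_bins) and the 'fn,btfc->btnc' einsum projects the linear-frequency axis onto mel bins. The shapes and the 80-bin value below are illustrative assumptions, and running the snippet requires the algoperf package on the path.

import torch
from algoperf.workloads.librispeech_conformer.librispeech_pytorch import preprocessor

# Fake magnitude spectrogram: (batch, frames, fft_bins, channels).
spect = torch.rand(2, 100, 257, 1)

mel_weights = preprocessor.linear_to_mel_weight_matrix(
  num_mel_bins=80,  # illustrative; the workload takes this from PreprocessorConfig
  num_spectrogram_bins=spect.shape[2],
)
print(mel_weights.shape)  # torch.Size([257, 80])

# Same contraction as in MelFilterbankFrontend.forward above.
mel_spectrogram = torch.einsum('fn,btfc->btnc', mel_weights, spect)
print(mel_spectrogram.shape)  # torch.Size([2, 100, 80, 1])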
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py
index 11b93703e..66db657b8 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/spectrum_augmenter.py
@@ -9,19 +9,21 @@
class SpecAug(nn.Module):
"""Layer performs masking prodecure along time and frequency axis.
- The procedure is detailed in https://arxiv.org/abs/1904.08779.
- This is an essential component in speech recognition models that
- helps achieve better word error rates.
- """
-
- def __init__(self,
- freq_mask_count: int = 1,
- freq_mask_max_bins: int = 15,
- time_mask_count: int = 1,
- time_mask_max_frames: int = 50,
- time_mask_max_ratio: float = 1.0,
- time_masks_per_frame: float = 0.0,
- use_dynamic_time_mask_max_frames: bool = False):
+ The procedure is detailed in https://arxiv.org/abs/1904.08779.
+ This is an essential component in speech recognition models that
+ helps achieve better word error rates.
+ """
+
+ def __init__(
+ self,
+ freq_mask_count: int = 1,
+ freq_mask_max_bins: int = 15,
+ time_mask_count: int = 1,
+ time_mask_max_frames: int = 50,
+ time_mask_max_ratio: float = 1.0,
+ time_masks_per_frame: float = 0.0,
+ use_dynamic_time_mask_max_frames: bool = False,
+ ):
super().__init__()
self.freq_mask_count = freq_mask_count
@@ -35,23 +37,26 @@ def __init__(self,
def next_prng_key(self, name='dropout'):
return self.make_rng(name)
- def _get_mask(self,
- batch_size,
- choose_range,
- mask_size,
- max_length=None,
- masks_per_frame=0.0,
- multiplicity=1,
- max_ratio=1.0,
- device='cpu'):
+ def _get_mask(
+ self,
+ batch_size,
+ choose_range,
+ mask_size,
+ max_length=None,
+ masks_per_frame=0.0,
+ multiplicity=1,
+ max_ratio=1.0,
+ device='cpu',
+ ):
# Sample lengths for multiple masks.
if max_length and max_length > 0:
max_length = max_length * torch.ones(batch_size, device=device)
else:
max_length = choose_range * max_ratio
masked_portion = torch.rand(batch_size, multiplicity, device=device)
- masked_frame_size = torch.einsum('b,bm->bm', max_length,
- masked_portion).long()
+ masked_frame_size = torch.einsum(
+ 'b,bm->bm', max_length, masked_portion
+ ).long()
# Make sure the sampled length was smaller than max_ratio * length_bound.
# Note that sampling in this way was biased
# (shorter sequence may over-masked.)
@@ -80,8 +85,9 @@ def _get_mask(self,
# Sum masks with appropriate multiplicity.
if masks_per_frame > 0:
multiplicity_weights = torch.tile(
- torch.arange(multiplicity, device=device).long()[None, ...],
- [batch_size, 1])
+ torch.arange(multiplicity, device=device).long()[None, ...],
+ [batch_size, 1],
+ )
multiplicity_tensor = masks_per_frame * choose_range
multiplicity_weights = (multiplicity_weights < multiplicity_tensor).long()
pre_mask = torch.einsum('bmt,bm->bt', pre_mask, multiplicity_weights)
@@ -99,8 +105,9 @@ def _time_mask(self, inputs, length):
max_ratio = self.time_mask_max_ratio
# If maximum mask length is zero, do nothing.
- if ((time_mask_max_frames == 0 and not use_dynamic_time_mask_max_frames) or
- max_ratio <= 0.0):
+ if (
+ time_mask_max_frames == 0 and not use_dynamic_time_mask_max_frames
+ ) or max_ratio <= 0.0:
return inputs
if multiplicity == 0:
return inputs
@@ -112,14 +119,15 @@ def _time_mask(self, inputs, length):
time_mask_max_frames = None
# Create masks in time direction and apply.
block_arrays = self._get_mask(
- batch_size,
- choose_range=length,
- mask_size=time_length,
- max_length=time_mask_max_frames,
- masks_per_frame=self.time_masks_per_frame,
- multiplicity=multiplicity,
- max_ratio=max_ratio,
- device=inputs.device)
+ batch_size,
+ choose_range=length,
+ mask_size=time_length,
+ max_length=time_mask_max_frames,
+ masks_per_frame=self.time_masks_per_frame,
+ multiplicity=multiplicity,
+ max_ratio=max_ratio,
+ device=inputs.device,
+ )
outputs = torch.einsum('bxy,bx->bxy', inputs, block_arrays)
return outputs
@@ -138,14 +146,15 @@ def _frequency_mask(self, inputs):
choose_range = num_freq * torch.ones(batch_size, device=inputs.device)
# Create masks in frequency direction and apply.
block_arrays = self._get_mask(
- batch_size,
- choose_range=choose_range,
- mask_size=num_freq,
- max_length=freq_mask_max_bins,
- masks_per_frame=0.0,
- multiplicity=multiplicity,
- max_ratio=1.0,
- device=inputs.device)
+ batch_size,
+ choose_range=choose_range,
+ mask_size=num_freq,
+ max_length=freq_mask_max_bins,
+ masks_per_frame=0.0,
+ multiplicity=multiplicity,
+ max_ratio=1.0,
+ device=inputs.device,
+ )
outputs = torch.einsum('bxy,by->bxy', inputs, block_arrays)
return outputs
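Editorial sketch (not part of the diff) of the einsum masking applied by _time_mask and _frequency_mask above: a per-example {0, 1} mask over one axis is broadcast across the other axis with 'bxy,bx->bxy' (time) or 'bxy,by->bxy' (frequency). The tiny tensors below are placeholders.

import torch

features = torch.rand(2, 6, 4)   # (batch, time, freq)
time_mask = torch.ones(2, 6)     # 1.0 keeps a frame, 0.0 masks it
time_mask[:, 2:4] = 0.0          # mask frames 2 and 3 in every example

masked = torch.einsum('bxy,bx->bxy', features, time_mask)
assert torch.all(masked[:, 2:4] == 0)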
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
index dbeabb16c..25416682c 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/workload.py
@@ -10,15 +10,12 @@
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
-from algoperf import data_utils
-from algoperf import param_utils
-from algoperf import pytorch_utils
-from algoperf import spec
import algoperf.random_utils as prng
-from algoperf.workloads.librispeech_conformer import metrics
-from algoperf.workloads.librispeech_conformer import workload
-from algoperf.workloads.librispeech_conformer.input_pipeline import \
- LibriSpeechDataset
+from algoperf import data_utils, param_utils, pytorch_utils, spec
+from algoperf.workloads.librispeech_conformer import metrics, workload
+from algoperf.workloads.librispeech_conformer.input_pipeline import (
+ LibriSpeechDataset,
+)
from algoperf.workloads.librispeech_conformer.librispeech_pytorch import models
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_utils.pytorch_setup()
@@ -27,10 +24,9 @@
class LibriSpeechConformerWorkload(workload.BaseLibrispeechWorkload):
-
- def __init__(self,
- tokenizer_vocab_path: Optional[str] = None,
- use_specaug: bool = True) -> None:
+ def __init__(
+ self, tokenizer_vocab_path: Optional[str] = None, use_specaug: bool = True
+ ) -> None:
super().__init__()
self.tokenizer = metrics.load_tokenizer(tokenizer_vocab_path)
self.use_specaug = use_specaug
@@ -42,7 +38,8 @@ def __init__(self,
def eval_num_workers(self) -> int:
if self._eval_num_workers is None:
raise ValueError(
- 'eval_num_workers property must be set before workload is used.')
+ 'eval_num_workers property must be set before workload is used.'
+ )
return self._eval_num_workers
@eval_num_workers.setter
@@ -74,11 +71,13 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
else:
activation_function_name = 'swish'
model = models.ConformerEncoderDecoder(
- models.ConformerConfig(
- use_specaug=self.use_specaug,
- attention_temperature=self.attention_temperature,
- use_post_layer_norm=self.use_post_layer_norm,
- activation_function_name=activation_function_name))
+ models.ConformerConfig(
+ use_specaug=self.use_specaug,
+ attention_temperature=self.attention_temperature,
+ use_post_layer_norm=self.use_post_layer_norm,
+ activation_function_name=activation_function_name,
+ )
+ )
self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none')
models.initialize(model)
self._param_shapes = param_utils.pytorch_param_shapes(model)
@@ -97,14 +96,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['lin.weight', 'lin.bias']
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
@@ -115,30 +114,33 @@ def model_fn(
if mode == spec.ForwardPassMode.TRAIN:
model.train()
model.apply(
- functools.partial(
- pytorch_utils.update_batch_norm_fn,
- update_batch_norm=update_batch_norm))
+ functools.partial(
+ pytorch_utils.update_batch_norm_fn,
+ update_batch_norm=update_batch_norm,
+ )
+ )
contexts = {
- spec.ForwardPassMode.EVAL: torch.no_grad,
- spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
+ spec.ForwardPassMode.EVAL: torch.no_grad,
+ spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
- logits, logits_paddings = model(inputs.to(DEVICE),
- input_paddings.to(DEVICE),
- dropout_rate=dropout_rate)
+ logits, logits_paddings = model(
+ inputs.to(DEVICE), input_paddings.to(DEVICE), dropout_rate=dropout_rate
+ )
return (logits, logits_paddings), None
def _build_input_queue(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
del cache
del repeat_final_dataset
del num_batches
@@ -157,7 +159,7 @@ def _build_input_queue(
if split == 'eval_train':
indices = list(range(len(ds)))
random.Random(int(data_rng[0])).shuffle(indices)
- ds = torch.utils.data.Subset(ds, indices[:self.num_eval_train_examples])
+ ds = torch.utils.data.Subset(ds, indices[: self.num_eval_train_examples])
sampler = None
if USE_PYTORCH_DDP:
@@ -168,31 +170,36 @@ def _build_input_queue(
if USE_PYTORCH_DDP:
if is_train:
sampler = torch.utils.data.distributed.DistributedSampler(
- ds, num_replicas=N_GPUS, rank=RANK, shuffle=True)
+ ds, num_replicas=N_GPUS, rank=RANK, shuffle=True
+ )
else:
sampler = data_utils.DistributedEvalSampler(
- ds, num_replicas=N_GPUS, rank=RANK, shuffle=False)
+ ds, num_replicas=N_GPUS, rank=RANK, shuffle=False
+ )
dataloader = torch.utils.data.DataLoader(
- ds,
- batch_size=ds_iter_batch_size,
- shuffle=not USE_PYTORCH_DDP and is_train,
- sampler=sampler,
- num_workers=4,
- pin_memory=True,
- drop_last=is_train)
+ ds,
+ batch_size=ds_iter_batch_size,
+ shuffle=not USE_PYTORCH_DDP and is_train,
+ sampler=sampler,
+ num_workers=4,
+ pin_memory=True,
+ drop_last=is_train,
+ )
dataloader = data_utils.cycle(
- dataloader, custom_sampler=USE_PYTORCH_DDP, use_mixup=False)
+ dataloader, custom_sampler=USE_PYTORCH_DDP, use_mixup=False
+ )
return dataloader
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding)
- logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding)
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding)
+ logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding)
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -206,10 +213,8 @@ def loss_fn(
input_lengths = torch.einsum('bh->b', 1 - logit_paddings).long()
target_lengths = torch.einsum('bh->b', 1 - target_paddings).long()
per_example_losses = self.ctc_loss(
- logprobs.permute(1, 0, 2),
- targets.long(),
- input_lengths,
- target_lengths)
+ logprobs.permute(1, 0, 2), targets.long(), input_lengths, target_lengths
+ )
# mask_batch is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -220,21 +225,22 @@ def loss_fn(
summed_loss = per_example_losses.sum()
n_valid_examples = max(n_valid_examples, 1)
return {
- 'summed': summed_loss,
- 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
+ 'per_example': per_example_losses,
}
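As a quick standalone reference, a minimal sketch of the length bookkeeping above, assuming 0/1 padding masks where 1 marks a padded position and hypothetical tensor sizes:

import torch

ctc = torch.nn.CTCLoss(blank=0, reduction='none')
logprobs = torch.randn(4, 17, 32).log_softmax(-1)  # [batch, time, vocab]
logit_paddings = torch.zeros(4, 17)  # 0 = valid frame
targets = torch.randint(1, 32, (4, 9))  # blank id 0 never appears in targets
target_paddings = torch.zeros(4, 9)
input_lengths = torch.einsum('bh->b', 1 - logit_paddings).long()
target_lengths = torch.einsum('bh->b', 1 - target_paddings).long()
# CTCLoss expects time-major log-probabilities, hence the permute to [time, batch, vocab].
per_example = ctc(
  logprobs.permute(1, 0, 2), targets.long(), input_lengths, target_lengths
)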
def greedy_decode(
- self, logits: spec.Tensor,
- logit_paddings: spec.Tensor) -> Tuple[spec.Tensor, spec.Tensor]:
+ self, logits: spec.Tensor, logit_paddings: spec.Tensor
+ ) -> Tuple[spec.Tensor, spec.Tensor]:
framewise_tokens = logits.max(dim=-1)[1]
framewise_tokens = framewise_tokens * (1 - logit_paddings)
# Add sentinel because unique_consecutive will flatten array
# and then compute the unique.
framewise_tokens = torch.cat(
- [framewise_tokens, -torch.ones_like(framewise_tokens[:, 0:1])], dim=1)
+ [framewise_tokens, -torch.ones_like(framewise_tokens[:, 0:1])], dim=1
+ )
_, indices = torch.unique_consecutive(framewise_tokens, return_inverse=True)
indices -= indices.min(dim=1, keepdims=True)[0]
result = torch.zeros_like(framewise_tokens)
@@ -247,11 +253,12 @@ def greedy_decode(
# Remove blanks (id = 0).
blank_id = 0
fin_result = torch.zeros_like(result)
- idxs = torch.arange(
- fin_result.numel(), device=result.device).view(*fin_result.shape)
- mask = torch.arange(
- fin_result.shape[1], device=result.device).view(
- 1, -1) < result.count_nonzero(dim=1).view(-1, 1)
+ idxs = torch.arange(fin_result.numel(), device=result.device).view(
+ *fin_result.shape
+ )
+ mask = torch.arange(fin_result.shape[1], device=result.device).view(
+ 1, -1
+ ) < result.count_nonzero(dim=1).view(-1, 1)
fin_result.view(-1)[idxs[mask != 0]] = result[result != blank_id]
padding = fin_result == 0
return fin_result, padding
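A toy, non-batched illustration of the collapse-repeats-then-drop-blanks behaviour that the sentinel-based batched code above implements, using hypothetical token ids and blank id 0:

import torch

frames = torch.tensor([2, 2, 0, 3, 3, 3, 0, 0, 4])  # framewise argmax tokens
collapsed = torch.unique_consecutive(frames)  # tensor([2, 0, 3, 0, 4])
decoded = collapsed[collapsed != 0]  # tensor([2, 3, 4])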
@@ -265,29 +272,31 @@ def sync_sd(self, params: spec.ParameterContainer) -> None:
sd[k] = sd[k] / N_GPUS
params.load_state_dict(sd)
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del global_step
data_rng, model_rng = prng.split(rng, 2)
if split not in self._eval_iters:
# These iterators repeat indefinitely.
- self._eval_iters[split] = (
- self._build_input_queue(
- data_rng, split, data_dir, global_batch_size=global_batch_size))
+ self._eval_iters[split] = self._build_input_queue(
+ data_rng, split, data_dir, global_batch_size=global_batch_size
+ )
total_metrics = {
- 'loss': torch.tensor(0., device=DEVICE),
- 'lengths': torch.tensor(0., device=DEVICE),
- 'word_errors': torch.tensor(0., device=DEVICE),
- 'num_words': torch.tensor(0., device=DEVICE),
+ 'loss': torch.tensor(0.0, device=DEVICE),
+ 'lengths': torch.tensor(0.0, device=DEVICE),
+ 'word_errors': torch.tensor(0.0, device=DEVICE),
+ 'num_words': torch.tensor(0.0, device=DEVICE),
}
num_batches = int(math.ceil(num_examples / global_batch_size))
if self.requires_sync_before_eval:
@@ -296,48 +305,50 @@ def _eval_model_on_split(self,
batch = next(self._eval_iters[split])
(logits, logits_padding), _ = self.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.EVAL,
- model_rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.EVAL,
+ model_rng,
+ update_batch_norm=False,
+ )
decoded, decoded_paddings = self.greedy_decode(logits, logits_padding)
targets, target_paddings = batch['targets']
word_errors, num_words = metrics.compute_wer(
- decoded=decoded.cpu().numpy(),
- decoded_paddings=decoded_paddings.cpu().numpy(),
- targets=targets.cpu().numpy(),
- target_paddings=target_paddings.cpu().numpy(),
- tokenizer=self.tokenizer)
+ decoded=decoded.cpu().numpy(),
+ decoded_paddings=decoded_paddings.cpu().numpy(),
+ targets=targets.cpu().numpy(),
+ target_paddings=target_paddings.cpu().numpy(),
+ tokenizer=self.tokenizer,
+ )
loss = self.loss_fn((targets, target_paddings), (logits, logits_padding))
summed_loss = loss['summed']
lengths = loss['n_valid_examples']
batch_metrics = {
- 'loss': summed_loss,
- 'lengths': lengths,
- 'word_errors': word_errors,
- 'num_words': num_words,
+ 'loss': summed_loss,
+ 'lengths': lengths,
+ 'word_errors': word_errors,
+ 'num_words': num_words,
}
total_metrics = {
- k: v + batch_metrics[k] for k, v in total_metrics.items()
+ k: v + batch_metrics[k] for k, v in total_metrics.items()
}
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
dist.all_reduce(metric)
return {
- 'ctc_loss':
- float(total_metrics['loss'].item() /
- total_metrics['lengths'].item()),
- 'wer':
- float(total_metrics['word_errors'].item() /
- total_metrics['num_words'].item()),
+ 'ctc_loss': float(
+ total_metrics['loss'].item() / total_metrics['lengths'].item()
+ ),
+ 'wer': float(
+ total_metrics['word_errors'].item() / total_metrics['num_words'].item()
+ ),
}
class LibriSpeechConformerAttentionTemperatureWorkload(
- LibriSpeechConformerWorkload):
-
+ LibriSpeechConformerWorkload
+):
@property
def attention_temperature(self) -> float:
return 1.6
@@ -352,7 +363,6 @@ def test_target_value(self) -> float:
class LibriSpeechConformerLayerNormWorkload(LibriSpeechConformerWorkload):
-
@property
def use_post_layer_norm(self) -> bool:
return False
@@ -367,7 +377,6 @@ def test_target_value(self) -> float:
class LibriSpeechConformerGeluWorkload(LibriSpeechConformerWorkload):
-
@property
def use_gelu(self) -> bool:
return True
diff --git a/algoperf/workloads/librispeech_conformer/metrics.py b/algoperf/workloads/librispeech_conformer/metrics.py
index de74cfe1b..d5c826575 100644
--- a/algoperf/workloads/librispeech_conformer/metrics.py
+++ b/algoperf/workloads/librispeech_conformer/metrics.py
@@ -15,17 +15,20 @@ def average_ctc_loss():
@flax.struct.dataclass
class _Metric(metrics.Metric):
"""Applies `fun` and computes the average."""
+
total: np.float32
weight: np.float32
@classmethod
def from_model_output(cls, loss_dict, **_):
return cls(
- total=loss_dict['summed'], weight=loss_dict['n_valid_examples'])
+ total=loss_dict['summed'], weight=loss_dict['n_valid_examples']
+ )
def merge(self, other):
return type(self)(
- total=self.total + other.total, weight=self.weight + other.weight)
+ total=self.total + other.total, weight=self.weight + other.weight
+ )
def compute(self):
return self.total / self.weight
@@ -74,9 +77,10 @@ def edit_distance(source, target):
# possibilities and find minimum.
else:
distance[i][j] = 1 + min(
- distance[i][j - 1], # Insert
- distance[i - 1][j], # Remove
- distance[i - 1][j - 1]) # Replace
+ distance[i][j - 1], # Insert
+ distance[i - 1][j], # Remove
+ distance[i - 1][j - 1],
+ ) # Replace
return distance[num_source_words][num_target_words]
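The recurrence above is the standard word-level Levenshtein distance; a compact reference sketch, assuming `source` and `target` are lists of words:

def word_edit_distance(source, target):
  m, n = len(source), len(target)
  d = [[0] * (n + 1) for _ in range(m + 1)]
  for i in range(m + 1):
    for j in range(n + 1):
      if i == 0:
        d[i][j] = j  # insert all of target
      elif j == 0:
        d[i][j] = i  # remove all of source
      elif source[i - 1] == target[j - 1]:
        d[i][j] = d[i - 1][j - 1]
      else:
        d[i][j] = 1 + min(d[i][j - 1], d[i - 1][j], d[i - 1][j - 1])
  return d[m][n]

assert word_edit_distance('a b c'.split(), 'a x c'.split()) == 1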
@@ -109,17 +113,20 @@ def compute_wer(decoded, decoded_paddings, targets, target_paddings, tokenizer):
return word_errors, num_words
-def load_tokenizer(model_path: str,
- add_bos: bool = False,
- add_eos: bool = True,
- reverse: bool = False):
+def load_tokenizer(
+ model_path: str,
+ add_bos: bool = False,
+ add_eos: bool = True,
+ reverse: bool = False,
+):
"""Load a tf-text SentencePiece tokenizer from given model filepath."""
if model_path is None:
return None
with gfile.GFile(model_path, 'rb') as model_fp:
sp_model = model_fp.read()
sp_tokenizer = tftxt.SentencepieceTokenizer(
- model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse)
+ model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse
+ )
return sp_tokenizer
@@ -128,8 +135,10 @@ def wer(tokenizer_vocab_path):
@flax.struct.dataclass
class WER(
- metrics.CollectingMetric.from_outputs(
- ('decoded', 'decoded_paddings', 'targets', 'target_paddings'))):
+ metrics.CollectingMetric.from_outputs(
+ ('decoded', 'decoded_paddings', 'targets', 'target_paddings')
+ )
+ ):
"""Computes the mean average precision for a binary classifier on CPU."""
def compute(self):
@@ -144,7 +153,8 @@ def compute(self):
values['decoded_paddings'],
values['targets'].astype(np.int32),
values['target_paddings'],
- tokenizer)
+ tokenizer,
+ )
return word_errors / num_words
@@ -153,4 +163,5 @@ def compute(self):
def get_metrics_bundle(tokenizer_vocab_path):
return metrics.Collection.create(
- ctc_loss=average_ctc_loss(), wer=wer(tokenizer_vocab_path))
+ ctc_loss=average_ctc_loss(), wer=wer(tokenizer_vocab_path)
+ )
diff --git a/algoperf/workloads/librispeech_conformer/workload.py b/algoperf/workloads/librispeech_conformer/workload.py
index 94f01dd97..791270719 100644
--- a/algoperf/workloads/librispeech_conformer/workload.py
+++ b/algoperf/workloads/librispeech_conformer/workload.py
@@ -5,7 +5,6 @@
class BaseLibrispeechWorkload(spec.Workload):
-
_num_outputs: int = 1024
@property
@@ -25,8 +24,9 @@ def use_gelu(self) -> bool:
def attention_temperature(self) -> float:
raise NotImplementedError
- def has_reached_validation_target(self, eval_result: Dict[str,
- float]) -> bool:
+ def has_reached_validation_target(
+ self, eval_result: Dict[str, float]
+ ) -> bool:
return eval_result['validation/wer'] < self.validation_target_value
@property
@@ -53,8 +53,9 @@ def num_eval_train_examples(self) -> int:
# Round up from num_validation_examples (which is the default for
# num_eval_train_examples) to the next multiple of eval_batch_size, so that
# we don't have to extract the correctly sized subset of the training data.
- rounded_up_multiple = math.ceil(self.num_validation_examples /
- self.eval_batch_size)
+ rounded_up_multiple = math.ceil(
+ self.num_validation_examples / self.eval_batch_size
+ )
return rounded_up_multiple * self.eval_batch_size
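For concreteness, with hypothetical values of 5348 validation examples and an eval batch size of 256 the rounding works out as:

import math

rounded_up_multiple = math.ceil(5348 / 256)  # 21
num_eval_train_examples = rounded_up_multiple * 256  # 5376, the next multiple of 256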
@property
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
index 262fc1a95..225852b28 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/models.py
@@ -10,17 +10,19 @@
from typing import Any, Dict, List, Optional, Tuple, Union
+import jax
+import jax.numpy as jnp
from flax import linen as nn
from flax import struct
-import jax
from jax.experimental import rnn
-import jax.numpy as jnp
from algoperf.jax_utils import Dropout
-from algoperf.workloads.librispeech_conformer.librispeech_jax import \
- librispeech_preprocessor as preprocessor
-from algoperf.workloads.librispeech_conformer.librispeech_jax import \
- spectrum_augmenter
+from algoperf.workloads.librispeech_conformer.librispeech_jax import (
+ librispeech_preprocessor as preprocessor,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_jax import (
+ spectrum_augmenter,
+)
Array = jnp.ndarray
StateType = Union[Array, Tuple[Array, ...]]
@@ -37,6 +39,7 @@
@struct.dataclass
class DeepspeechConfig:
"""Global hyperparameters used to minimize obnoxious kwarg plumbing."""
+
vocab_size: int = 1024
dtype: Any = jnp.float32
encoder_dim: int = 512
@@ -68,6 +71,7 @@ class Subsample(nn.Module):
encoder_dim: model dimension of conformer.
input_dropout_rate: dropout rate for inputs.
"""
+
config: DeepspeechConfig
@nn.compact
@@ -76,38 +80,40 @@ def __call__(self, inputs, output_paddings, train, dropout_rate=DROPOUT_RATE):
outputs = jnp.expand_dims(inputs, axis=-1)
outputs, output_paddings = Conv2dSubsampling(
- encoder_dim=config.encoder_dim,
- dtype=config.dtype,
- batch_norm_momentum=config.batch_norm_momentum,
- batch_norm_epsilon=config.batch_norm_epsilon,
- input_channels=1,
- output_channels=config.encoder_dim,
- use_tanh=config.use_tanh
- )(outputs, output_paddings, train)
+ encoder_dim=config.encoder_dim,
+ dtype=config.dtype,
+ batch_norm_momentum=config.batch_norm_momentum,
+ batch_norm_epsilon=config.batch_norm_epsilon,
+ input_channels=1,
+ output_channels=config.encoder_dim,
+ use_tanh=config.use_tanh,
+ )(outputs, output_paddings, train)
outputs, output_paddings = Conv2dSubsampling(
- encoder_dim=config.encoder_dim,
- dtype=config.dtype,
- batch_norm_momentum=config.batch_norm_momentum,
- batch_norm_epsilon=config.batch_norm_epsilon,
- input_channels=config.encoder_dim,
- output_channels=config.encoder_dim,
- use_tanh=config.use_tanh)(outputs, output_paddings, train)
+ encoder_dim=config.encoder_dim,
+ dtype=config.dtype,
+ batch_norm_momentum=config.batch_norm_momentum,
+ batch_norm_epsilon=config.batch_norm_epsilon,
+ input_channels=config.encoder_dim,
+ output_channels=config.encoder_dim,
+ use_tanh=config.use_tanh,
+ )(outputs, output_paddings, train)
batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape
outputs = jnp.reshape(
- outputs, (batch_size, subsampled_lengths, subsampled_dims * channels))
+ outputs, (batch_size, subsampled_lengths, subsampled_dims * channels)
+ )
outputs = nn.Dense(
- config.encoder_dim,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- outputs)
+ config.encoder_dim,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(outputs)
- outputs = Dropout(
- rate=dropout_rate, deterministic=not train)(
- outputs, rate=dropout_rate)
+ outputs = Dropout(rate=dropout_rate, deterministic=not train)(
+ outputs, rate=dropout_rate
+ )
return outputs, output_paddings
@@ -119,6 +125,7 @@ class Conv2dSubsampling(nn.Module):
2) Also performs strided convolution over input_paddings to return the correct
paddings for downstream layers.
"""
+
input_channels: int = 0
output_channels: int = 0
filter_stride: List[int] = (2, 2)
@@ -131,24 +138,26 @@ class Conv2dSubsampling(nn.Module):
def setup(self):
self.filter_shape = (3, 3, self.input_channels, self.output_channels)
- self.kernel = self.param('kernel',
- nn.initializers.xavier_uniform(),
- self.filter_shape)
+ self.kernel = self.param(
+ 'kernel', nn.initializers.xavier_uniform(), self.filter_shape
+ )
self.bias = self.param(
- 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels)
+ 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels
+ )
@nn.compact
def __call__(self, inputs, paddings, train):
# Computing strided convolution to subsample inputs.
feature_group_count = inputs.shape[3] // self.filter_shape[2]
outputs = jax.lax.conv_general_dilated(
- lhs=inputs,
- rhs=self.kernel,
- window_strides=self.filter_stride,
- padding=self.padding,
- rhs_dilation=(1, 1),
- dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
- feature_group_count=feature_group_count)
+ lhs=inputs,
+ rhs=self.kernel,
+ window_strides=self.filter_stride,
+ padding=self.padding,
+ rhs_dilation=(1, 1),
+ dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
+ feature_group_count=feature_group_count,
+ )
outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,))
@@ -163,48 +172,49 @@ def __call__(self, inputs, paddings, train):
pad_len = (input_length + stride - 1) // stride * stride - input_length
out_padding = jax.lax.conv_general_dilated(
- lhs=paddings[:, :, None],
- rhs=jnp.ones([1, 1, 1]),
- window_strides=self.filter_stride[:1],
- padding=[(0, pad_len)],
- dimension_numbers=('NHC', 'HIO', 'NHC'))
+ lhs=paddings[:, :, None],
+ rhs=jnp.ones([1, 1, 1]),
+ window_strides=self.filter_stride[:1],
+ padding=[(0, pad_len)],
+ dimension_numbers=('NHC', 'HIO', 'NHC'),
+ )
out_padding = jnp.squeeze(out_padding, axis=-1)
# Mask outputs by correct paddings to ensure padded elements in inputs map
# to padded value in outputs.
- outputs = outputs * (1.0 -
- jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1))
+ outputs = outputs * (
+ 1.0 - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1)
+ )
return outputs, out_padding
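A framework-free toy illustration of how the stride-2 padding convolution above subsamples a 0/1 padding vector (hypothetical values; with a length-1 kernel of ones it reduces to sampling every second entry):

stride = 2
paddings = [0, 0, 0, 0, 0, 0, 1, 1]  # 6 valid frames, 2 padded
pad_len = (len(paddings) + stride - 1) // stride * stride - len(paddings)  # 0 here
out_padding = (paddings + [0] * pad_len)[::stride]  # [0, 0, 0, 1] -> 3 valid frames remain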
class FeedForwardModule(nn.Module):
"""Feedforward block of conformer layer."""
+
config: DeepspeechConfig
@nn.compact
- def __call__(self,
- inputs,
- input_paddings=None,
- train=False,
- dropout_rate=DROPOUT_RATE):
+ def __call__(
+ self, inputs, input_paddings=None, train=False, dropout_rate=DROPOUT_RATE
+ ):
padding_mask = jnp.expand_dims(1 - input_paddings, -1)
config = self.config
if config.layernorm_everywhere:
inputs = LayerNorm(config.encoder_dim)(inputs)
else:
- inputs = BatchNorm(config.encoder_dim,
- config.dtype,
- config.batch_norm_momentum,
- config.batch_norm_epsilon)(inputs,
- input_paddings,
- train)
- inputs = nn.Dense(
+ inputs = BatchNorm(
config.encoder_dim,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- inputs)
+ config.dtype,
+ config.batch_norm_momentum,
+ config.batch_norm_epsilon,
+ )(inputs, input_paddings, train)
+ inputs = nn.Dense(
+ config.encoder_dim,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(inputs)
if config.use_tanh:
inputs = nn.tanh(inputs)
else:
@@ -212,7 +222,8 @@ def __call__(self,
inputs *= padding_mask
inputs = Dropout(rate=dropout_rate)(
- inputs, deterministic=not train, rate=dropout_rate)
+ inputs, deterministic=not train, rate=dropout_rate
+ )
return inputs
@@ -227,6 +238,7 @@ class LayerNorm(nn.Module):
zeros, this differs from default flax implementation of multiplying by scale
and initializing to ones.
"""
+
dim: int = 0
epsilon: float = 1e-6
@@ -240,7 +252,7 @@ def __call__(self, inputs):
var = jnp.mean(jnp.square(inputs - mean), axis=-1, keepdims=True)
normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon)
- normed_inputs *= (1 + self.scale)
+ normed_inputs *= 1 + self.scale
normed_inputs += self.bias
return normed_inputs
@@ -258,6 +270,7 @@ class BatchNorm(nn.Module):
and the corresponding defaults for momentum and epsilon have been copied over
from lingvo.
"""
+
encoder_dim: int = 0
dtype: Any = jnp.float32
batch_norm_momentum: float = 0.999
@@ -267,14 +280,12 @@ def setup(self):
dim = self.encoder_dim
dtype = self.dtype
- self.ra_mean = self.variable('batch_stats',
- 'mean',
- lambda s: jnp.zeros(s, dtype),
- dim)
- self.ra_var = self.variable('batch_stats',
- 'var',
- lambda s: jnp.ones(s, dtype),
- dim)
+ self.ra_mean = self.variable(
+ 'batch_stats', 'mean', lambda s: jnp.zeros(s, dtype), dim
+ )
+ self.ra_var = self.variable(
+ 'batch_stats', 'var', lambda s: jnp.ones(s, dtype), dim
+ )
self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype)
self.beta = self.param('bias', nn.initializers.zeros, dim, dtype)
@@ -303,7 +314,8 @@ def __call__(self, inputs, input_paddings=None, train=False):
mask = 1.0 - padding
sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True)
count_v = jnp.sum(
- jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True)
+ jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True
+ )
sum_v = jax.lax.psum(sum_v, axis_name='batch')
count_v = jax.lax.psum(count_v, axis_name='batch')
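A toy, single-device illustration of the masked statistics above; in the module itself the sums additionally go through `jax.lax.psum` over the 'batch' axis so that every device normalizes with the same mean and variance:

import jax.numpy as jnp

inputs = jnp.array([[1.0, 2.0, 3.0, 0.0]])  # last frame is padding
paddings = jnp.array([[0.0, 0.0, 0.0, 1.0]])  # 1 = padded
mask = 1.0 - paddings
mean = jnp.sum(inputs * mask) / jnp.sum(mask)  # 2.0, not the unmasked 1.5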
@@ -340,15 +352,14 @@ class CudnnLSTM(nn.Module):
@nn.compact
def __call__(
- self,
- inputs: Array,
- segmentation_mask: Optional[Array] = None,
- return_carry: Optional[bool] = None,
- deterministic: bool = False,
- initial_states: Optional[Tuple[Array, Array]] = None,
- use_cuda: bool = True,
+ self,
+ inputs: Array,
+ segmentation_mask: Optional[Array] = None,
+ return_carry: Optional[bool] = None,
+ deterministic: bool = False,
+ initial_states: Optional[Tuple[Array, Array]] = None,
+ use_cuda: bool = True,
) -> Union[Array, Tuple[Array, Carry]]:
-
if jax.devices()[0].platform != 'gpu':
use_cuda = False
@@ -358,22 +369,22 @@ def __call__(
dropout = 0.0 if deterministic else self.dropout_rate
weights = self.param(
- 'weights',
- rnn.init_lstm_weight,
- input_size,
- self.features,
- self.num_layers,
- self.bidirectional,
+ 'weights',
+ rnn.init_lstm_weight,
+ input_size,
+ self.features,
+ self.num_layers,
+ self.bidirectional,
)
if initial_states is None:
h_0 = jnp.zeros(
- (num_directions * self.num_layers, batch_size, self.features),
- jnp.float32,
+ (num_directions * self.num_layers, batch_size, self.features),
+ jnp.float32,
)
c_0 = jnp.zeros(
- (num_directions * self.num_layers, batch_size, self.features),
- jnp.float32,
+ (num_directions * self.num_layers, batch_size, self.features),
+ jnp.float32,
)
else:
h_0, c_0 = initial_states
@@ -385,20 +396,35 @@ def __call__(
if use_cuda:
y, h, c = rnn.lstm(
- x=inputs, h_0=h_0, c_0=c_0, weights=weights,
- seq_lengths=seq_lengths, input_size=input_size,
- hidden_size=self.features, num_layers=self.num_layers,
- dropout=dropout, bidirectional=self.bidirectional,
+ x=inputs,
+ h_0=h_0,
+ c_0=c_0,
+ weights=weights,
+ seq_lengths=seq_lengths,
+ input_size=input_size,
+ hidden_size=self.features,
+ num_layers=self.num_layers,
+ dropout=dropout,
+ bidirectional=self.bidirectional,
)
else:
weight_ih, weight_hh, bias_ih, bias_hh = self.unpack_weights(
- weights, input_size)
+ weights, input_size
+ )
y, h, c = rnn.lstm_ref(
- x=inputs, h_0=h_0, c_0=c_0, W_ih=weight_ih, W_hh=weight_hh,
- b_ih=bias_ih, b_hh=bias_hh, seq_lengths=seq_lengths,
- input_size=input_size, hidden_size=self.features,
- num_layers=self.num_layers, dropout=dropout,
- bidirectional=self.bidirectional,
+ x=inputs,
+ h_0=h_0,
+ c_0=c_0,
+ W_ih=weight_ih,
+ W_hh=weight_hh,
+ b_ih=bias_ih,
+ b_hh=bias_hh,
+ seq_lengths=seq_lengths,
+ input_size=input_size,
+ hidden_size=self.features,
+ num_layers=self.num_layers,
+ dropout=dropout,
+ bidirectional=self.bidirectional,
)
if return_carry:
@@ -408,21 +434,22 @@ def __call__(
@nn.nowrap
def unpack_weights(
- self, weights: Array, input_size: int
+ self, weights: Array, input_size: int
) -> Tuple[
- Dict[int, Array], Dict[int, Array], Dict[int, Array], Dict[int, Array]]:
+ Dict[int, Array], Dict[int, Array], Dict[int, Array], Dict[int, Array]
+ ]:
return jax.experimental.rnn.unpack_lstm_weights(
- weights,
- input_size,
- self.features,
- self.num_layers,
- self.bidirectional,
+ weights,
+ input_size,
+ self.features,
+ self.num_layers,
+ self.bidirectional,
)
class BatchRNN(nn.Module):
- """Implements a single deepspeech encoder layer.
- """
+ """Implements a single deepspeech encoder layer."""
+
config: DeepspeechConfig
@nn.compact
@@ -432,16 +459,17 @@ def __call__(self, inputs, input_paddings, train):
if config.layernorm_everywhere:
inputs = LayerNorm(config.encoder_dim)(inputs)
else:
- inputs = BatchNorm(config.encoder_dim,
- config.dtype,
- config.batch_norm_momentum,
- config.batch_norm_epsilon)(inputs,
- input_paddings,
- train)
+ inputs = BatchNorm(
+ config.encoder_dim,
+ config.dtype,
+ config.batch_norm_momentum,
+ config.batch_norm_epsilon,
+ )(inputs, input_paddings, train)
output = CudnnLSTM(
- features=config.encoder_dim // 2,
- bidirectional=config.bidirectional,
- num_layers=1)(inputs, input_paddings)
+ features=config.encoder_dim // 2,
+ bidirectional=config.bidirectional,
+ num_layers=1,
+ )(inputs, input_paddings)
return output
@@ -453,18 +481,19 @@ class Deepspeech(nn.Module):
for each time step. The output is then fed into a CTC loss which eliminates
the need for alignment with targets.
"""
+
config: DeepspeechConfig
def setup(self):
config = self.config
self.specaug = spectrum_augmenter.SpecAug(
- freq_mask_count=config.freq_mask_count,
- freq_mask_max_bins=config.freq_mask_max_bins,
- time_mask_count=config.time_mask_count,
- time_mask_max_frames=config.time_mask_max_frames,
- time_mask_max_ratio=config.time_mask_max_ratio,
- time_masks_per_frame=config.time_masks_per_frame,
- use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames
+ freq_mask_count=config.freq_mask_count,
+ freq_mask_max_bins=config.freq_mask_max_bins,
+ time_mask_count=config.time_mask_count,
+ time_mask_max_frames=config.time_mask_max_frames,
+ time_mask_max_ratio=config.time_mask_max_ratio,
+ time_masks_per_frame=config.time_masks_per_frame,
+ use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames,
)
@nn.compact
@@ -477,10 +506,10 @@ def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE):
# Compute normalized log mel spectrograms from input audio signal.
preprocessing_config = preprocessor.LibrispeechPreprocessingConfig()
outputs, output_paddings = preprocessor.MelFilterbankFrontend(
- preprocessing_config,
- per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
- per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)(outputs,
- output_paddings)
+ preprocessing_config,
+ per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
+ per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR,
+ )(outputs, output_paddings)
# Ablate random parts of input along temporal and frequency dimension
# following the specaug procedure in https://arxiv.org/abs/1904.08779.
@@ -488,9 +517,9 @@ def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE):
outputs, output_paddings = self.specaug(outputs, output_paddings)
# Subsample input by a factor of 4 by performing strided convolutions.
- outputs, output_paddings = Subsample(
- config=config)(outputs, output_paddings, train,
- dropout_rate=dropout_rate)
+ outputs, output_paddings = Subsample(config=config)(
+ outputs, output_paddings, train, dropout_rate=dropout_rate
+ )
# Run the lstm layers.
for _ in range(config.num_lstm_layers):
@@ -502,19 +531,21 @@ def __call__(self, inputs, input_paddings, train, dropout_rate=DROPOUT_RATE):
for _ in range(config.num_ffn_layers):
if config.enable_residual_connections:
outputs = outputs + FeedForwardModule(config=self.config)(
- outputs, output_paddings, train)
+ outputs, output_paddings, train
+ )
else:
outputs = FeedForwardModule(config=self.config)(
- outputs, output_paddings, train, dropout_rate=dropout_rate)
+ outputs, output_paddings, train, dropout_rate=dropout_rate
+ )
# Run the decoder which in this case is a trivial projection layer.
if config.enable_decoder_layer_norm:
outputs = LayerNorm(config.encoder_dim)(outputs)
outputs = nn.Dense(
- config.vocab_size,
- use_bias=True,
- kernel_init=nn.initializers.xavier_uniform())(
- outputs)
+ config.vocab_size,
+ use_bias=True,
+ kernel_init=nn.initializers.xavier_uniform(),
+ )(outputs)
return outputs, output_paddings
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
index 81a56db72..b93934abf 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py
@@ -1,31 +1,29 @@
import functools
from typing import Dict, Optional, Tuple
-from flax import jax_utils
import jax
import jax.numpy as jnp
import numpy as np
+from flax import jax_utils
-from algoperf import param_utils
-from algoperf import spec
-from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import \
- LibriSpeechConformerWorkload
+from algoperf import param_utils, spec
+from algoperf.workloads.librispeech_conformer.librispeech_jax.workload import (
+ LibriSpeechConformerWorkload,
+)
from algoperf.workloads.librispeech_deepspeech.librispeech_jax import models
class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload):
-
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
- """Deepspeech model init function.
- """
+ """Deepspeech model init function."""
model_config = models.DeepspeechConfig(
- use_specaug=self.use_specaug,
- use_tanh=self.use_tanh,
- enable_residual_connections=self.enable_residual_connections,
- enable_decoder_layer_norm=self.enable_decoder_layer_norm,
- layernorm_everywhere=self.layernorm_everywhere,
- freq_mask_count=self.freq_mask_count,
- time_mask_count=self.time_mask_count,
+ use_specaug=self.use_specaug,
+ use_tanh=self.use_tanh,
+ enable_residual_connections=self.enable_residual_connections,
+ enable_decoder_layer_norm=self.enable_decoder_layer_norm,
+ layernorm_everywhere=self.layernorm_everywhere,
+ freq_mask_count=self.freq_mask_count,
+ time_mask_count=self.time_mask_count,
)
self._model = models.Deepspeech(model_config)
input_shape = [(320000,), (320000,)]
@@ -34,12 +32,16 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
model_init_fn = jax.jit(functools.partial(self._model.init, train=False))
params_rng, _ = jax.random.split(rng, 2)
- variables = model_init_fn({
+ variables = model_init_fn(
+ {
'params': params_rng,
- }, *fake_input_batch)
+ },
+ *fake_input_batch,
+ )
- model_state = variables[
- 'batch_stats'] if not self.layernorm_everywhere else {}
+ model_state = (
+ variables['batch_stats'] if not self.layernorm_everywhere else {}
+ )
params = variables['params']
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
@@ -48,36 +50,34 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
return params, model_state
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- use_running_average_bn: Optional[bool] = None,
- dropout_rate: Optional[bool] = models.DROPOUT_RATE,
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ use_running_average_bn: Optional[bool] = None,
+    dropout_rate: Optional[float] = models.DROPOUT_RATE,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
variables = {'params': params, **model_state}
inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs']
is_train_mode = mode == spec.ForwardPassMode.TRAIN
if update_batch_norm or is_train_mode:
(logits, logit_paddings), new_model_state = self._model.apply(
- variables,
- inputs,
- input_paddings,
- train=True,
- rngs={'dropout' : rng},
- mutable=['batch_stats'],
- dropout_rate=dropout_rate)
+ variables,
+ inputs,
+ input_paddings,
+ train=True,
+ rngs={'dropout': rng},
+ mutable=['batch_stats'],
+ dropout_rate=dropout_rate,
+ )
return (logits, logit_paddings), new_model_state
else:
logits, logit_paddings = self._model.apply(
- variables,
- inputs,
- input_paddings,
- train=False,
- mutable=False)
+ variables, inputs, input_paddings, train=False, mutable=False
+ )
return (logits, logit_paddings), model_state
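A minimal sketch of the Flax mutable-collection pattern used above, with a toy batch-norm module standing in for the Deepspeech model:

import flax.linen as nn
import jax
import jax.numpy as jnp

class Toy(nn.Module):
  @nn.compact
  def __call__(self, x, train):
    return nn.BatchNorm(use_running_average=not train)(x)

x = jnp.ones((4, 3))
variables = Toy().init(jax.random.PRNGKey(0), x, train=False)
# Train step: batch statistics are updated and handed back separately.
y, new_state = Toy().apply(variables, x, train=True, mutable=['batch_stats'])
# Eval step: nothing is mutated, so only the outputs come back.
y_eval = Toy().apply(variables, x, train=False, mutable=False)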
def is_output_params(self, param_key: spec.ParameterKey) -> bool:
@@ -126,7 +126,6 @@ def time_mask_count(self) -> int:
class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechDeepSpeechWorkload):
-
@property
def use_tanh(self) -> bool:
return True
@@ -141,7 +140,6 @@ def test_target_value(self) -> float:
class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechDeepSpeechWorkload):
-
@property
def enable_residual_connections(self) -> bool:
return False
@@ -155,9 +153,9 @@ def test_target_value(self) -> float:
return 0.079297
-class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload
- ):
-
+class LibriSpeechDeepSpeechNormAndSpecAugWorkload(
+ LibriSpeechDeepSpeechWorkload
+):
@property
def eval_batch_size(self) -> int:
return 128
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
index 3d8c000e1..ddb7b5c37 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
@@ -11,10 +11,12 @@
import torch.distributed.nn as dist_nn
import torch.nn.functional as F
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch import \
- preprocessor
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import \
- SpecAug
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch import (
+ preprocessor,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.spectrum_augmenter import (
+ SpecAug,
+)
USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ
DROPOUT_RATE = 0.1
@@ -23,6 +25,7 @@
@dataclass
class DeepspeechConfig:
"""Global hyperparameters used to minimize obnoxious kwarg plumbing."""
+
vocab_size: int = 1024
encoder_dim: int = 512
num_lstm_layers: int = 6
@@ -47,7 +50,6 @@ class DeepspeechConfig:
class LayerNorm(nn.Module):
-
def __init__(self, dim, epsilon=1e-6):
super().__init__()
self.dim = dim
@@ -61,14 +63,13 @@ def forward(self, x):
var = x.var(dim=-1, unbiased=False, keepdims=True)
normed_x = (x - mean) * torch.rsqrt(var + self.epsilon)
- normed_x *= (1 + self.scale)
+ normed_x *= 1 + self.scale
normed_x += self.bias
return normed_x
class Subsample(nn.Module):
-
def __init__(self, config: DeepspeechConfig):
super().__init__()
encoder_dim = config.encoder_dim
@@ -76,11 +77,13 @@ def __init__(self, config: DeepspeechConfig):
self.encoder_dim = encoder_dim
self.conv1 = Conv2dSubsampling(
- input_channels=1, output_channels=encoder_dim, use_tanh=config.use_tanh)
+ input_channels=1, output_channels=encoder_dim, use_tanh=config.use_tanh
+ )
self.conv2 = Conv2dSubsampling(
- input_channels=encoder_dim,
- output_channels=encoder_dim,
- use_tanh=config.use_tanh)
+ input_channels=encoder_dim,
+ output_channels=encoder_dim,
+ use_tanh=config.use_tanh,
+ )
self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True)
@@ -93,9 +96,9 @@ def forward(self, inputs, input_paddings, dropout_rate):
outputs, output_paddings = self.conv2(outputs, output_paddings)
batch_size, channels, subsampled_lengths, subsampled_dims = outputs.shape
- outputs = outputs.permute(0, 2, 3, 1).reshape(batch_size,
- subsampled_lengths,
- subsampled_dims * channels)
+ outputs = outputs.permute(0, 2, 3, 1).reshape(
+ batch_size, subsampled_lengths, subsampled_dims * channels
+ )
outputs = self.lin(outputs)
outputs = F.dropout(outputs, dropout_rate, training=self.training)
@@ -104,15 +107,16 @@ def forward(self, inputs, input_paddings, dropout_rate):
class Conv2dSubsampling(nn.Module):
-
- def __init__(self,
- input_channels: int,
- output_channels: int,
- filter_stride: Tuple[int] = (2, 2),
- padding: str = 'SAME',
- batch_norm_momentum: float = 0.999,
- batch_norm_epsilon: float = 0.001,
- use_tanh: bool = False):
+ def __init__(
+ self,
+ input_channels: int,
+ output_channels: int,
+ filter_stride: Tuple[int] = (2, 2),
+ padding: str = 'SAME',
+ batch_norm_momentum: float = 0.999,
+ batch_norm_epsilon: float = 0.001,
+ use_tanh: bool = False,
+ ):
super().__init__()
self.input_channels = input_channels
@@ -123,7 +127,8 @@ def __init__(self,
self.filter_shape = (output_channels, input_channels, 3, 3)
self.kernel = nn.Parameter(
- nn.init.xavier_uniform_(torch.empty(*self.filter_shape)))
+ nn.init.xavier_uniform_(torch.empty(*self.filter_shape))
+ )
self.bias = nn.Parameter(torch.zeros(output_channels))
self.use_tanh = use_tanh
@@ -154,12 +159,13 @@ def forward(self, inputs, paddings):
else:
in_ = inputs
outputs = F.conv2d(
- input=in_,
- weight=self.kernel,
- bias=self.bias,
- stride=self.filter_stride,
- dilation=(1, 1),
- groups=groups)
+ input=in_,
+ weight=self.kernel,
+ bias=self.bias,
+ stride=self.filter_stride,
+ dilation=(1, 1),
+ groups=groups,
+ )
if self.use_tanh:
outputs = F.tanh(outputs)
@@ -170,21 +176,24 @@ def forward(self, inputs, paddings):
stride = self.filter_stride[0]
pad_len = (input_length + stride - 1) // stride * stride - input_length
out_padding = F.conv1d(
- input=torch.cat([
- paddings[:, None, :],
- torch.zeros(
- size=(paddings.shape[0], 1, pad_len), device=paddings.device)
+ input=torch.cat(
+ [
+ paddings[:, None, :],
+ torch.zeros(
+ size=(paddings.shape[0], 1, pad_len), device=paddings.device
+ ),
],
- dim=2),
- weight=torch.ones([1, 1, 1], device=paddings.device),
- stride=self.filter_stride[:1])
+ dim=2,
+ ),
+ weight=torch.ones([1, 1, 1], device=paddings.device),
+ stride=self.filter_stride[:1],
+ )
out_padding = out_padding.squeeze(dim=1)
outputs = outputs * (1 - out_padding[:, None, :, None])
return outputs, out_padding
class FeedForwardModule(nn.Module):
-
def __init__(self, config: DeepspeechConfig):
super().__init__()
self.config = config
@@ -193,9 +202,10 @@ def __init__(self, config: DeepspeechConfig):
self.normalization_layer = LayerNorm(config.encoder_dim)
else:
self.bn_normalization_layer = BatchNorm(
- dim=config.encoder_dim,
- batch_norm_momentum=config.batch_norm_momentum,
- batch_norm_epsilon=config.batch_norm_epsilon)
+ dim=config.encoder_dim,
+ batch_norm_momentum=config.batch_norm_momentum,
+ batch_norm_epsilon=config.batch_norm_epsilon,
+ )
self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True)
def forward(self, inputs, input_paddings, dropout_rate):
@@ -220,7 +230,6 @@ def forward(self, inputs, input_paddings, dropout_rate):
class BatchNorm(nn.Module):
-
def __init__(self, dim, batch_norm_momentum, batch_norm_epsilon):
super().__init__()
running_mean = torch.zeros(dim)
@@ -235,8 +244,8 @@ def __init__(self, dim, batch_norm_momentum, batch_norm_epsilon):
self.dim = dim
def forward(self, inputs, input_paddings):
- #inputs: NHD
- #padding: NH
+ # inputs: NHD
+ # padding: NH
mask = 1 - input_paddings[:, :, None]
if self.training:
count = mask.sum()
@@ -253,9 +262,11 @@ def forward(self, inputs, input_paddings):
var = sum_ / count
self.running_mean = (1 - self.momentum) * self.running_mean + (
- self.momentum) * mean.detach()
+ self.momentum
+ ) * mean.detach()
self.running_var = (1 - self.momentum) * self.running_var + (
- self.momentum) * var.detach()
+ self.momentum
+ ) * var.detach()
else:
mean = self.running_mean
var = self.running_var
@@ -266,7 +277,6 @@ def forward(self, inputs, input_paddings):
class BatchRNN(nn.Module):
-
def __init__(self, config: DeepspeechConfig):
super().__init__()
self.config = config
@@ -278,19 +288,23 @@ def __init__(self, config: DeepspeechConfig):
if config.layernorm_everywhere:
self.normalization_layer = LayerNorm(config.encoder_dim)
else:
- self.bn_normalization_layer = BatchNorm(config.encoder_dim,
- config.batch_norm_momentum,
- config.batch_norm_epsilon)
+ self.bn_normalization_layer = BatchNorm(
+ config.encoder_dim,
+ config.batch_norm_momentum,
+ config.batch_norm_epsilon,
+ )
if bidirectional:
self.lstm = nn.LSTM(
- input_size=input_size,
- hidden_size=hidden_size // 2,
- bidirectional=True,
- batch_first=True)
+ input_size=input_size,
+ hidden_size=hidden_size // 2,
+ bidirectional=True,
+ batch_first=True,
+ )
else:
self.lstm = nn.LSTM(
- input_size=input_size, hidden_size=hidden_size, batch_first=True)
+ input_size=input_size, hidden_size=hidden_size, batch_first=True
+ )
def forward(self, inputs, input_paddings):
if self.config.layernorm_everywhere:
@@ -299,50 +313,59 @@ def forward(self, inputs, input_paddings):
inputs = self.bn_normalization_layer(inputs, input_paddings)
lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy()
packed_inputs = torch.nn.utils.rnn.pack_padded_sequence(
- inputs, lengths, batch_first=True, enforce_sorted=False)
+ inputs, lengths, batch_first=True, enforce_sorted=False
+ )
packed_outputs, _ = self.lstm(packed_inputs)
outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(
- packed_outputs, batch_first=True)
+ packed_outputs, batch_first=True
+ )
if outputs.shape[1] < inputs.shape[1]:
- outputs = torch.cat([
+ outputs = torch.cat(
+ [
outputs,
torch.zeros(
- size=(outputs.shape[0],
- inputs.shape[1] - outputs.shape[1],
- outputs.shape[2]),
- device=outputs.device)
- ],
- dim=1)
+ size=(
+ outputs.shape[0],
+ inputs.shape[1] - outputs.shape[1],
+ outputs.shape[2],
+ ),
+ device=outputs.device,
+ ),
+ ],
+ dim=1,
+ )
return outputs
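A minimal sketch of the pack/unpack round trip used above, with hypothetical sizes; `pad_packed_sequence` only pads back to the longest length in the batch, which is why the code right-pads to the original time dimension afterwards:

import torch

lstm = torch.nn.LSTM(input_size=8, hidden_size=8, batch_first=True)
inputs = torch.randn(3, 10, 8)  # [batch, time, features]
lengths = torch.tensor([10, 7, 4])  # valid frames per example
packed = torch.nn.utils.rnn.pack_padded_sequence(
  inputs, lengths, batch_first=True, enforce_sorted=False
)
packed_out, _ = lstm(packed)
outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
assert outputs.shape[1] == int(lengths.max())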
class DeepspeechEncoderDecoder(nn.Module):
-
def __init__(self, config: DeepspeechConfig):
super().__init__()
self.config = config
self.specaug = SpecAug(
- freq_mask_count=config.freq_mask_count,
- freq_mask_max_bins=config.freq_mask_max_bins,
- time_mask_count=config.time_mask_count,
- time_mask_max_frames=config.time_mask_max_frames,
- time_mask_max_ratio=config.time_mask_max_ratio,
- time_masks_per_frame=config.time_masks_per_frame,
- use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames
+ freq_mask_count=config.freq_mask_count,
+ freq_mask_max_bins=config.freq_mask_max_bins,
+ time_mask_count=config.time_mask_count,
+ time_mask_max_frames=config.time_mask_max_frames,
+ time_mask_max_ratio=config.time_mask_max_ratio,
+ time_masks_per_frame=config.time_masks_per_frame,
+ use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames,
)
preprocessing_config = preprocessor.PreprocessorConfig()
self.preprocessor = preprocessor.MelFilterbankFrontend(
- preprocessing_config,
- per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
- per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)
+ preprocessing_config,
+ per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR,
+ per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR,
+ )
self.subsample = Subsample(config=config)
self.lstms = nn.ModuleList(
- [BatchRNN(config) for _ in range(config.num_lstm_layers)])
+ [BatchRNN(config) for _ in range(config.num_lstm_layers)]
+ )
self.ffns = nn.ModuleList(
- [FeedForwardModule(config) for _ in range(config.num_ffn_layers)])
+ [FeedForwardModule(config) for _ in range(config.num_ffn_layers)]
+ )
if config.enable_decoder_layer_norm:
self.ln = LayerNorm(config.encoder_dim)
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
index 0b9ce1e3c..672f3440f 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py
@@ -3,18 +3,19 @@
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
-from algoperf import param_utils
-from algoperf import spec
+from algoperf import param_utils, spec
from algoperf.pytorch_utils import pytorch_setup
from algoperf.workloads.librispeech_conformer.librispeech_pytorch import models
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import \
- initialize
-from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import \
- LibriSpeechConformerWorkload
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \
- DeepspeechConfig
-from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import \
- DeepspeechEncoderDecoder
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.models import (
+ initialize,
+)
+from algoperf.workloads.librispeech_conformer.librispeech_pytorch.workload import (
+ LibriSpeechConformerWorkload,
+)
+from algoperf.workloads.librispeech_deepspeech.librispeech_pytorch.models import (
+ DeepspeechConfig,
+ DeepspeechEncoderDecoder,
+)
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
@@ -22,19 +23,20 @@
class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload):
-
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
"""Deepspeech model init function."""
torch.random.manual_seed(rng[0])
model = DeepspeechEncoderDecoder(
- DeepspeechConfig(
- use_specaug=self.use_specaug,
- use_tanh=self.use_tanh,
- enable_residual_connections=self.enable_residual_connections,
- enable_decoder_layer_norm=self.enable_decoder_layer_norm,
- layernorm_everywhere=self.layernorm_everywhere,
- freq_mask_count=self.freq_mask_count,
- time_mask_count=self.time_mask_count)).eval()
+ DeepspeechConfig(
+ use_specaug=self.use_specaug,
+ use_tanh=self.use_tanh,
+ enable_residual_connections=self.enable_residual_connections,
+ enable_decoder_layer_norm=self.enable_decoder_layer_norm,
+ layernorm_everywhere=self.layernorm_everywhere,
+ freq_mask_count=self.freq_mask_count,
+ time_mask_count=self.time_mask_count,
+ )
+ ).eval()
self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none')
# Run model once to initialize lazy layers.
t = MAX_INPUT_LENGTH
@@ -55,24 +57,26 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
return model, None
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
# override super method, changing only the default dropout_rate
# pylint: disable=useless-parent-delegation
- return super().model_fn(params,
- augmented_and_preprocessed_input_batch,
- model_state,
- mode,
- rng,
- update_batch_norm,
- dropout_rate)
+ return super().model_fn(
+ params,
+ augmented_and_preprocessed_input_batch,
+ model_state,
+ mode,
+ rng,
+ update_batch_norm,
+ dropout_rate,
+ )
def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['lin.weight', 'lin.bias']
@@ -120,7 +124,6 @@ def time_mask_count(self) -> int:
class LibriSpeechDeepSpeechTanhWorkload(LibriSpeechDeepSpeechWorkload):
-
@property
def use_tanh(self) -> bool:
return True
@@ -135,7 +138,6 @@ def test_target_value(self) -> float:
class LibriSpeechDeepSpeechNoResNetWorkload(LibriSpeechDeepSpeechWorkload):
-
@property
def enable_residual_connections(self) -> bool:
return False
@@ -149,9 +151,9 @@ def test_target_value(self) -> float:
return 0.079297
-class LibriSpeechDeepSpeechNormAndSpecAugWorkload(LibriSpeechDeepSpeechWorkload
- ):
-
+class LibriSpeechDeepSpeechNormAndSpecAugWorkload(
+ LibriSpeechDeepSpeechWorkload
+):
@property
def eval_batch_size(self) -> int:
return 128
diff --git a/algoperf/workloads/mnist/mnist_jax/workload.py b/algoperf/workloads/mnist/mnist_jax/workload.py
index 27bd9ae54..0d192cbf5 100644
--- a/algoperf/workloads/mnist/mnist_jax/workload.py
+++ b/algoperf/workloads/mnist/mnist_jax/workload.py
@@ -3,20 +3,18 @@
import functools
from typing import Any, Dict, Optional, Tuple
-from flax import jax_utils
-from flax import linen as nn
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from flax import linen as nn
+from jax import lax
-from algoperf import param_utils
-from algoperf import spec
+from algoperf import param_utils, spec
from algoperf.workloads.mnist.workload import BaseMnistWorkload
class _Model(nn.Module):
-
@nn.compact
def __call__(self, x: spec.Tensor, train: bool) -> spec.Tensor:
del train
@@ -31,12 +29,12 @@ def __call__(self, x: spec.Tensor, train: bool) -> spec.Tensor:
class MnistWorkload(BaseMnistWorkload):
-
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
self._model = _Model()
- initial_params = self._model.init({'params': rng}, init_val,
- train=True)['params']
+ initial_params = self._model.init({'params': rng}, init_val, train=True)[
+ 'params'
+ ]
self._param_shapes = param_utils.jax_param_shapes(initial_params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
return jax_utils.replicate(initial_params), None
@@ -45,31 +43,34 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_1'
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
train = mode == spec.ForwardPassMode.TRAIN
logits_batch = self._model.apply(
- {'params': params},
- augmented_and_preprocessed_input_batch['inputs'],
- train=train)
+ {'params': params},
+ augmented_and_preprocessed_input_batch['inputs'],
+ train=train,
+ )
return logits_batch, None
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -79,7 +80,8 @@ def loss_fn(
one_hot_targets = jax.nn.one_hot(label_batch, 10)
smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing)
per_example_losses = -jnp.sum(
- smoothed_targets * nn.log_softmax(logits_batch), axis=-1)
+ smoothed_targets * nn.log_softmax(logits_batch), axis=-1
+ )
# `mask_batch` is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -88,41 +90,45 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': n_valid_examples,
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': n_valid_examples,
+ 'per_example': per_example_losses,
}
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, 0, 0, 0, None),
- static_broadcasted_argnums=(0,))
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, 0, 0, 0, None),
+ static_broadcasted_argnums=(0,),
+ )
def _eval_model(
- self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
logits, _ = self.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.EVAL,
- rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.EVAL,
+ rng,
+ update_batch_norm=False,
+ )
weights = batch.get('weights')
if weights is None:
weights = jnp.ones(len(logits))
accuracy = jnp.sum(
- (jnp.argmax(logits, axis=-1) == batch['targets']) * weights)
+ (jnp.argmax(logits, axis=-1) == batch['targets']) * weights
+ )
summed_loss = self.loss_fn(batch['targets'], logits, weights)['summed']
metrics = {'accuracy': accuracy, 'loss': summed_loss}
metrics = lax.psum(metrics, axis_name='batch')
return metrics
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
return jax.tree.map(lambda x: float(x[0] / num_examples), total_metrics)
diff --git a/algoperf/workloads/mnist/mnist_pytorch/workload.py b/algoperf/workloads/mnist/mnist_pytorch/workload.py
index 780e1bca0..ca861d551 100644
--- a/algoperf/workloads/mnist/mnist_pytorch/workload.py
+++ b/algoperf/workloads/mnist/mnist_pytorch/workload.py
@@ -20,18 +20,20 @@
class _Model(nn.Module):
-
def __init__(self) -> None:
super().__init__()
input_size = 28 * 28
num_hidden = 128
num_classes = 10
self.net = nn.Sequential(
- OrderedDict([('layer1',
- torch.nn.Linear(input_size, num_hidden, bias=True)),
- ('layer1_sig', torch.nn.Sigmoid()),
- ('layer2',
- torch.nn.Linear(num_hidden, num_classes, bias=True))]))
+ OrderedDict(
+ [
+ ('layer1', torch.nn.Linear(input_size, num_hidden, bias=True)),
+ ('layer1_sig', torch.nn.Sigmoid()),
+ ('layer2', torch.nn.Linear(num_hidden, num_classes, bias=True)),
+ ]
+ )
+ )
def reset_parameters(self) -> None:
for m in self.net.modules():
@@ -44,16 +46,16 @@ def forward(self, x: spec.Tensor) -> spec.Tensor:
class MnistWorkload(BaseMnistWorkload):
-
def _build_input_queue(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
del cache
if N_GPUS != 0:
per_device_batch_size = int(global_batch_size / N_GPUS)
@@ -63,22 +65,27 @@ def _build_input_queue(
# Only create and iterate over tf input pipeline in one Python process to
# avoid creating too many threads.
if RANK == 0:
- np_iter = super()._build_input_queue(data_rng,
- split,
- data_dir,
- global_batch_size,
- num_batches,
- repeat_final_dataset)
+ np_iter = super()._build_input_queue(
+ data_rng,
+ split,
+ data_dir,
+ global_batch_size,
+ num_batches,
+ repeat_final_dataset,
+ )
while True:
if RANK == 0:
batch = next(np_iter) # pylint: disable=stop-iteration-return
inputs = torch.as_tensor(
- batch['inputs'], dtype=torch.float32, device=DEVICE)
+ batch['inputs'], dtype=torch.float32, device=DEVICE
+ )
targets = torch.as_tensor(
- batch['targets'], dtype=torch.long, device=DEVICE)
+ batch['targets'], dtype=torch.long, device=DEVICE
+ )
if 'weights' in batch:
weights = torch.as_tensor(
- batch['weights'], dtype=torch.bool, device=DEVICE)
+ batch['weights'], dtype=torch.bool, device=DEVICE
+ )
else:
weights = torch.ones_like(targets, dtype=torch.bool, device=DEVICE)
# Send batch to other devices when using DDP.
@@ -94,34 +101,37 @@ def _build_input_queue(
targets = targets.view(-1, *targets.shape[2:])
weights = weights.view(-1, *weights.shape[2:])
else:
- inputs = torch.empty((N_GPUS, per_device_batch_size, 28, 28, 1),
- dtype=torch.float32,
- device=DEVICE)
+ inputs = torch.empty(
+ (N_GPUS, per_device_batch_size, 28, 28, 1),
+ dtype=torch.float32,
+ device=DEVICE,
+ )
dist.broadcast(inputs, src=0)
inputs = inputs[RANK]
- targets = torch.empty((N_GPUS, per_device_batch_size),
- dtype=torch.long,
- device=DEVICE)
+ targets = torch.empty(
+ (N_GPUS, per_device_batch_size), dtype=torch.long, device=DEVICE
+ )
dist.broadcast(targets, src=0)
targets = targets[RANK]
- weights = torch.empty((N_GPUS, per_device_batch_size),
- dtype=torch.bool,
- device=DEVICE)
+ weights = torch.empty(
+ (N_GPUS, per_device_batch_size), dtype=torch.bool, device=DEVICE
+ )
dist.broadcast(weights, src=0)
weights = weights[RANK]
batch = {
- 'inputs': inputs.permute(0, 3, 1, 2),
- 'targets': targets,
- 'weights': weights,
+ 'inputs': inputs.permute(0, 3, 1, 2),
+ 'targets': targets,
+ 'weights': weights,
}
yield batch
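A minimal sketch of the rank-0 broadcast pattern used above, assuming `torch.distributed` is already initialised and that every rank knows the batch shape and dtype up front:

import torch
import torch.distributed as dist

def broadcast_from_rank0(batch_on_rank0, shape, dtype, device, rank):
  if rank == 0:
    tensor = batch_on_rank0.to(device)
  else:
    tensor = torch.empty(shape, dtype=dtype, device=device)  # placeholder to receive into
  dist.broadcast(tensor, src=0)  # afterwards every rank holds rank 0's data
  return tensor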
def init_model_fn(
- self,
- rng: spec.RandomState,
- dropout_rate: Optional[float] = None,
- aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+ self,
+ rng: spec.RandomState,
+ dropout_rate: Optional[float] = None,
+ aux_dropout_rate: Optional[float] = None,
+ ) -> spec.ModelInitState:
"""Dropout is unused."""
del dropout_rate
del aux_dropout_rate
@@ -149,13 +159,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['net.layer2.weight', 'net_layer2.bias']
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
del update_batch_norm
@@ -163,8 +174,8 @@ def model_fn(
if mode == spec.ForwardPassMode.EVAL:
model.eval()
contexts = {
- spec.ForwardPassMode.EVAL: torch.no_grad,
- spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
+ spec.ForwardPassMode.EVAL: torch.no_grad,
+ spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
logits_batch = model(augmented_and_preprocessed_input_batch['inputs'])
@@ -173,11 +184,12 @@ def model_fn(
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -185,10 +197,11 @@ def loss_fn(
(not synced across devices).
"""
per_example_losses = F.cross_entropy(
- logits_batch,
- label_batch,
- reduction='none',
- label_smoothing=label_smoothing)
+ logits_batch,
+ label_batch,
+ reduction='none',
+ label_smoothing=label_smoothing,
+ )
# `mask_batch` is assumed to be shape [batch].
if mask_batch is not None:
per_example_losses *= mask_batch
@@ -197,25 +210,27 @@ def loss_fn(
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE),
+ 'per_example': per_example_losses,
}
def _eval_model(
- self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
"""Return the mean accuracy and loss as a dict."""
logits, _ = self.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.EVAL,
- rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.EVAL,
+ rng,
+ update_batch_norm=False,
+ )
targets = batch['targets']
weights = batch.get('weights')
if weights is None:
@@ -227,8 +242,8 @@ def _eval_model(
return {'accuracy': accuracy, 'loss': summed_loss}
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
if USE_PYTORCH_DDP:
for metric in total_metrics.values():
dist.all_reduce(metric)
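A minimal shape-only sketch (illustrative, not part of the diff) of the broadcast layout the hunks above rely on: rank 0 lays the global batch out as (N_GPUS, per_device_batch_size, 28, 28, 1), every other rank receives that stack via dist.broadcast, keeps its own slice, and permutes NHWC to NCHW. Sizes and values below are placeholders.

import torch

N_GPUS, per_device_batch_size = 2, 4  # hypothetical
stacked = torch.arange(
  N_GPUS * per_device_batch_size * 28 * 28, dtype=torch.float32
).reshape(N_GPUS, per_device_batch_size, 28, 28, 1)
# In the workload this stacked tensor is what dist.broadcast(inputs, src=0)
# transfers; here we only show the per-rank slicing and layout change after it.
for rank in range(N_GPUS):
  local_inputs = stacked[rank]                     # (4, 28, 28, 1), NHWC
  local_inputs = local_inputs.permute(0, 3, 1, 2)  # (4, 1, 28, 28), NCHW as in the diff
  assert local_inputs.shape == (per_device_batch_size, 1, 28, 28)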
diff --git a/algoperf/workloads/mnist/workload.py b/algoperf/workloads/mnist/workload.py
index f53aadd0b..20aa975ae 100644
--- a/algoperf/workloads/mnist/workload.py
+++ b/algoperf/workloads/mnist/workload.py
@@ -23,16 +23,17 @@ def _normalize(image: spec.Tensor, mean: float, stddev: float) -> spec.Tensor:
def _build_mnist_dataset(
- data_rng: jax.random.PRNGKey,
- num_train_examples: int,
- num_validation_examples: int,
- train_mean: float,
- train_stddev: float,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: bool = False,
- repeat_final_dataset: bool = True) -> Iterator[Dict[str, spec.Tensor]]:
+ data_rng: jax.random.PRNGKey,
+ num_train_examples: int,
+ num_validation_examples: int,
+ train_mean: float,
+ train_stddev: float,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: bool = False,
+ repeat_final_dataset: bool = True,
+) -> Iterator[Dict[str, spec.Tensor]]:
shuffle = split in ['train', 'eval_train']
assert num_train_examples + num_validation_examples == 60000
if shuffle:
@@ -42,12 +43,14 @@ def _build_mnist_dataset(
else:
tfds_split = 'test'
ds = tfds.load(
- 'mnist', split=tfds_split, shuffle_files=False, data_dir=data_dir)
+ 'mnist', split=tfds_split, shuffle_files=False, data_dir=data_dir
+ )
ds = ds.map(
- lambda x: {
- 'inputs': _normalize(x['image'], train_mean, train_stddev),
- 'targets': x['label'],
- })
+ lambda x: {
+ 'inputs': _normalize(x['image'], train_mean, train_stddev),
+ 'targets': x['label'],
+ }
+ )
is_train = split == 'train'
if cache:
@@ -62,22 +65,23 @@ def _build_mnist_dataset(
ds = ds.repeat()
ds = map(
- functools.partial(
- data_utils.shard_and_maybe_pad_np,
- global_batch_size=global_batch_size),
- ds)
+ functools.partial(
+ data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size
+ ),
+ ds,
+ )
return iter(ds)
class BaseMnistWorkload(spec.Workload):
-
@property
def target_metric_name(self) -> str:
"""The name of the target metric (useful for scoring/processing code)."""
return 'accuracy'
- def has_reached_validation_target(self, eval_result: Dict[str,
- float]) -> bool:
+ def has_reached_validation_target(
+ self, eval_result: Dict[str, float]
+ ) -> bool:
return eval_result['validation/accuracy'] > self.validation_target_value
@property
@@ -104,8 +108,9 @@ def num_eval_train_examples(self) -> int:
# Round up from num_validation_examples (which is the default for
# num_eval_train_examples) to the next multiple of eval_batch_size, so that
# we don't have to extract the correctly sized subset of the training data.
- rounded_up_multiple = math.ceil(self.num_validation_examples /
- self.eval_batch_size)
+ rounded_up_multiple = math.ceil(
+ self.num_validation_examples / self.eval_batch_size
+ )
return rounded_up_multiple * self.eval_batch_size
@property
@@ -138,31 +143,33 @@ def eval_period_time_sec(self) -> int:
@abc.abstractmethod
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
def _build_input_queue(
- self,
- data_rng: spec.RandomState,
- split: str,
- data_dir: str,
- global_batch_size: int,
- cache: Optional[bool] = None,
- repeat_final_dataset: Optional[bool] = None,
- num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
+ self,
+ data_rng: spec.RandomState,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ cache: Optional[bool] = None,
+ repeat_final_dataset: Optional[bool] = None,
+ num_batches: Optional[int] = None,
+ ) -> Iterator[Dict[str, spec.Tensor]]:
del num_batches
ds = _build_mnist_dataset(
- data_rng=data_rng,
- num_train_examples=self.num_train_examples,
- num_validation_examples=self.num_validation_examples,
- train_mean=self.train_mean,
- train_stddev=self.train_stddev,
- split=split,
- data_dir=data_dir,
- global_batch_size=global_batch_size,
- cache=cache,
- repeat_final_dataset=repeat_final_dataset)
+ data_rng=data_rng,
+ num_train_examples=self.num_train_examples,
+ num_validation_examples=self.num_validation_examples,
+ train_mean=self.train_mean,
+ train_stddev=self.train_stddev,
+ split=split,
+ data_dir=data_dir,
+ global_batch_size=global_batch_size,
+ cache=cache,
+ repeat_final_dataset=repeat_final_dataset,
+ )
return ds
@property
@@ -173,49 +180,52 @@ def step_hint(self) -> int:
return 7813
def _eval_model(
- self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ ) -> Dict[spec.Tensor, spec.ModelAuxiliaryState]:
raise NotImplementedError
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del global_step
data_rng, model_rng = prng.split(rng, 2)
if split not in self._eval_iters:
self._eval_iters[split] = self._build_input_queue(
- data_rng=data_rng,
- split=split,
- data_dir=data_dir,
- global_batch_size=global_batch_size,
- cache=True,
- repeat_final_dataset=True)
+ data_rng=data_rng,
+ split=split,
+ data_dir=data_dir,
+ global_batch_size=global_batch_size,
+ cache=True,
+ repeat_final_dataset=True,
+ )
total_metrics = {
- 'accuracy': 0.,
- 'loss': 0.,
+ 'accuracy': 0.0,
+ 'loss': 0.0,
}
num_batches = int(math.ceil(num_examples / global_batch_size))
num_devices = max(torch.cuda.device_count(), jax.local_device_count())
for _ in range(num_batches):
batch = next(self._eval_iters[split])
per_device_model_rngs = prng.split(model_rng, num_devices)
- batch_metrics = self._eval_model(params,
- batch,
- model_state,
- per_device_model_rngs)
+ batch_metrics = self._eval_model(
+ params, batch, model_state, per_device_model_rngs
+ )
total_metrics = {
- k: v + batch_metrics[k] for k, v in total_metrics.items()
+ k: v + batch_metrics[k] for k, v in total_metrics.items()
}
return self._normalize_eval_metrics(num_examples, total_metrics)
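For intuition on the num_eval_train_examples rounding above, a tiny check with hypothetical sizes (the real values are workload properties): rounding up to a multiple of the eval batch size means the eval-train split always fills whole batches.

import math

num_validation_examples = 10_000  # hypothetical
eval_batch_size = 1_024           # hypothetical
rounded_up_multiple = math.ceil(num_validation_examples / eval_batch_size)  # 10
num_eval_train_examples = rounded_up_multiple * eval_batch_size             # 10240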
diff --git a/algoperf/workloads/ogbg/input_pipeline.py b/algoperf/workloads/ogbg/input_pipeline.py
index 3cb6f51de..79a4ddc4a 100644
--- a/algoperf/workloads/ogbg/input_pipeline.py
+++ b/algoperf/workloads/ogbg/input_pipeline.py
@@ -14,10 +14,10 @@
AVG_EDGES_PER_GRAPH = 56
TFDS_SPLIT_NAME = {
- 'train': 'train',
- 'eval_train': 'train',
- 'validation': 'validation',
- 'test': 'test',
+ 'train': 'train',
+ 'eval_train': 'train',
+ 'validation': 'validation',
+ 'test': 'test',
}
@@ -33,11 +33,12 @@ def _load_dataset(split, should_shuffle, data_rng, data_dir):
read_config = tfds.ReadConfig(add_tfds_id=True, shuffle_seed=file_data_rng)
dataset = tfds.load(
- 'ogbg_molpcba:0.1.3',
- split=TFDS_SPLIT_NAME[split],
- shuffle_files=should_shuffle,
- read_config=read_config,
- data_dir=data_dir)
+ 'ogbg_molpcba:0.1.3',
+ split=TFDS_SPLIT_NAME[split],
+ shuffle_files=should_shuffle,
+ read_config=read_config,
+ data_dir=data_dir,
+ )
if should_shuffle:
dataset = dataset.shuffle(seed=dataset_data_rng, buffer_size=2**15)
@@ -62,16 +63,17 @@ def _to_jraph(example):
receivers = edge_index[:, 1]
return jraph.GraphsTuple(
- n_node=num_nodes,
- n_edge=np.array([len(edge_index) * 2]),
- nodes=node_feat,
- edges=np.concatenate([edge_feat, edge_feat]),
- # Make the edges bidirectional
- senders=np.concatenate([senders, receivers]),
- receivers=np.concatenate([receivers, senders]),
- # Keep the labels with the graph for batching. They will be removed
- # in the processed batch.
- globals=np.expand_dims(labels, axis=0))
+ n_node=num_nodes,
+ n_edge=np.array([len(edge_index) * 2]),
+ nodes=node_feat,
+ edges=np.concatenate([edge_feat, edge_feat]),
+ # Make the edges bidirectional
+ senders=np.concatenate([senders, receivers]),
+ receivers=np.concatenate([receivers, senders]),
+ # Keep the labels with the graph for batching. They will be removed
+ # in the processed batch.
+ globals=np.expand_dims(labels, axis=0),
+ )
def _get_weights_by_nan_and_padding(labels, padding_mask):
@@ -123,10 +125,9 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None):
max_n_graphs = per_device_batch_size
jraph_iter = map(_to_jraph, dataset_iter)
- batched_iter = jraph.dynamically_batch(jraph_iter,
- max_n_nodes + 1,
- max_n_edges,
- max_n_graphs + 1)
+ batched_iter = jraph.dynamically_batch(
+ jraph_iter, max_n_nodes + 1, max_n_edges, max_n_graphs + 1
+ )
count = 0
graphs_shards = []
@@ -141,7 +142,8 @@ def _get_batch_iterator(dataset_iter, global_batch_size, num_shards=None):
graph = batched_graph._replace(globals={})
replaced_labels, weights = _get_weights_by_nan_and_padding(
- labels, jraph.get_graph_padding_mask(graph))
+ labels, jraph.get_graph_padding_mask(graph)
+ )
graphs_shards.append(graph)
labels_shards.append(replaced_labels)
@@ -156,9 +158,9 @@ def f(x):
labels_shards = f(labels_shards)
weights_shards = f(weights_shards)
yield {
- 'inputs': graphs_shards,
- 'targets': labels_shards,
- 'weights': weights_shards,
+ 'inputs': graphs_shards,
+ 'targets': labels_shards,
+ 'weights': weights_shards,
}
count = 0
@@ -170,5 +172,6 @@ def f(x):
def get_dataset_iter(split, data_rng, data_dir, global_batch_size):
shuffle = split in ['train', 'eval_train']
ds = _load_dataset(
- split, should_shuffle=shuffle, data_rng=data_rng, data_dir=data_dir)
+ split, should_shuffle=shuffle, data_rng=data_rng, data_dir=data_dir
+ )
return _get_batch_iterator(iter(ds), global_batch_size)
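A toy numpy-only sketch of the bidirectional-edge construction in _to_jraph above; the +1 budgets passed to jraph.dynamically_batch then leave room for the padding graph and node that jraph appends to each batch.

import numpy as np

edge_index = np.array([[0, 1], [1, 2]])              # two directed edges of a 3-node path
senders, receivers = edge_index[:, 0], edge_index[:, 1]
bi_senders = np.concatenate([senders, receivers])    # [0 1 1 2]
bi_receivers = np.concatenate([receivers, senders])  # [1 2 0 1]
n_edge = np.array([len(edge_index) * 2])             # [4], matching the n_edge line above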
diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py
index 55f83d905..668501788 100644
--- a/algoperf/workloads/ogbg/metrics.py
+++ b/algoperf/workloads/ogbg/metrics.py
@@ -16,10 +16,9 @@
USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
-def predictions_match_labels(*,
- logits: jnp.ndarray,
- labels: jnp.ndarray,
- **kwargs) -> jnp.ndarray:
+def predictions_match_labels(
+ *, logits: jnp.ndarray, labels: jnp.ndarray, **kwargs
+) -> jnp.ndarray:
"""Returns a binary array indicating where predictions match the labels."""
del kwargs # Unused.
preds = logits > 0
@@ -28,7 +27,8 @@ def predictions_match_labels(*,
@flax.struct.dataclass
class MeanAveragePrecision(
- metrics.CollectingMetric.from_outputs(('logits', 'labels', 'mask'))):
+ metrics.CollectingMetric.from_outputs(('logits', 'labels', 'mask'))
+):
"""Computes the mean average precision (mAP) over different tasks."""
def compute(self):
@@ -62,7 +62,8 @@ def compute(self):
if np.sum(labels[:, task] == 0) > 0 and np.sum(labels[:, task] == 1) > 0:
is_labeled = mask[:, task]
average_precisions[task] = average_precision_score(
- labels[is_labeled, task], probs[is_labeled, task])
+ labels[is_labeled, task], probs[is_labeled, task]
+ )
# When all APs are NaNs, return NaN. This avoids raising a RuntimeWarning.
if np.isnan(average_precisions).all():
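For one task of MeanAveragePrecision.compute above: AP is computed only on labeled examples and only when both classes occur, otherwise the task stays NaN (and an all-NaN result is returned as NaN). A toy example with made-up values, using a sigmoid as a stand-in for however the metric derives probs; requires scikit-learn.

import numpy as np
from sklearn.metrics import average_precision_score

labels = np.array([1.0, 0.0, 1.0, np.nan])  # NaN marks an unlabeled example
logits = np.array([2.0, -1.0, 0.5, 3.0])
mask = ~np.isnan(labels)                    # plays the role of mask[:, task]
probs = 1.0 / (1.0 + np.exp(-logits))
if (labels[mask] == 0).any() and (labels[mask] == 1).any():
  ap = average_precision_score(labels[mask], probs[mask])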
diff --git a/algoperf/workloads/ogbg/ogbg_jax/models.py b/algoperf/workloads/ogbg/ogbg_jax/models.py
index 8524bb60e..db1ca416c 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/models.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/models.py
@@ -2,9 +2,9 @@
# https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py.
from typing import Tuple
-from flax import linen as nn
import jax.numpy as jnp
import jraph
+from flax import linen as nn
from algoperf.jax_utils import Dropout
@@ -12,7 +12,6 @@
def _make_embed(latent_dim, name):
-
def make_fn(inputs):
return nn.Dense(features=latent_dim, name=name)(inputs)
@@ -29,9 +28,9 @@ def make_fn(inputs):
x = nn.Dense(features=dim)(x)
x = nn.LayerNorm()(x)
x = activation_fn(x)
- x = Dropout(
- rate=dropout_rate, deterministic=not train)(
- x, rate=dropout_rate)
+ x = Dropout(rate=dropout_rate, deterministic=not train)(
+ x, rate=dropout_rate
+ )
return x
return make_fn
@@ -42,6 +41,7 @@ class GNN(nn.Module):
The model assumes the input data is a jraph.GraphsTuple without global
variables. The final prediction will be encoded in the globals.
"""
+
num_outputs: int
latent_dim: int = 256
hidden_dims: Tuple[int] = (256,)
@@ -50,13 +50,14 @@ class GNN(nn.Module):
@nn.compact
def __call__(self, graph, train, dropout_rate=DROPOUT_RATE):
-
graph = graph._replace(
- globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]))
+ globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs])
+ )
embedder = jraph.GraphMapFeatures(
- embed_node_fn=_make_embed(self.latent_dim, name='node_embedding'),
- embed_edge_fn=_make_embed(self.latent_dim, name='edge_embedding'))
+ embed_node_fn=_make_embed(self.latent_dim, name='node_embedding'),
+ embed_edge_fn=_make_embed(self.latent_dim, name='edge_embedding'),
+ )
graph = embedder(graph)
if self.activation_fn_name == 'relu':
@@ -67,25 +68,30 @@ def __call__(self, graph, train, dropout_rate=DROPOUT_RATE):
activation_fn = nn.silu
else:
raise ValueError(
- f'Invalid activation function name: {self.activation_fn_name}')
+ f'Invalid activation function name: {self.activation_fn_name}'
+ )
for _ in range(self.num_message_passing_steps):
net = jraph.GraphNetwork(
- update_edge_fn=_make_mlp(
- self.hidden_dims,
- activation_fn=activation_fn,
- train=train,
- dropout_rate=dropout_rate),
- update_node_fn=_make_mlp(
- self.hidden_dims,
- activation_fn=activation_fn,
- train=train,
- dropout_rate=dropout_rate),
- update_global_fn=_make_mlp(
- self.hidden_dims,
- activation_fn=activation_fn,
- train=train,
- dropout_rate=dropout_rate))
+ update_edge_fn=_make_mlp(
+ self.hidden_dims,
+ activation_fn=activation_fn,
+ train=train,
+ dropout_rate=dropout_rate,
+ ),
+ update_node_fn=_make_mlp(
+ self.hidden_dims,
+ activation_fn=activation_fn,
+ train=train,
+ dropout_rate=dropout_rate,
+ ),
+ update_global_fn=_make_mlp(
+ self.hidden_dims,
+ activation_fn=activation_fn,
+ train=train,
+ dropout_rate=dropout_rate,
+ ),
+ )
graph = net(graph)
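The hunks above thread a call-time dropout_rate from GNN.__call__ through _make_mlp into the forked Dropout layer. A minimal flax sketch of that threading (the standard nn.Dropout stands in for the forked layer, so rate is fixed at construction here rather than passed again at call time):

import jax
import jax.numpy as jnp
import flax.linen as nn

def make_mlp(dim):
  def fn(x, train, dropout_rate):
    x = nn.Dense(dim)(x)
    x = nn.LayerNorm()(x)
    x = nn.relu(x)
    return nn.Dropout(rate=dropout_rate, deterministic=not train)(x)
  return fn

class TinyNet(nn.Module):
  @nn.compact
  def __call__(self, x, train=True, dropout_rate=0.1):
    return make_mlp(8)(x, train, dropout_rate)

x = jnp.ones((2, 4))
variables = TinyNet().init(jax.random.PRNGKey(0), x, train=False)
out = TinyNet().apply(
  variables, x, train=True, dropout_rate=0.2,
  rngs={'dropout': jax.random.PRNGKey(1)},
)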
diff --git a/algoperf/workloads/ogbg/ogbg_jax/workload.py b/algoperf/workloads/ogbg/ogbg_jax/workload.py
index 0535aea83..8471fcdcc 100644
--- a/algoperf/workloads/ogbg/ogbg_jax/workload.py
+++ b/algoperf/workloads/ogbg/ogbg_jax/workload.py
@@ -1,39 +1,40 @@
"""OGBG workload implemented in Jax."""
+
import functools
from typing import Any, Dict, Tuple
-from flax import jax_utils
import jax
import jax.numpy as jnp
import jraph
import optax
+from flax import jax_utils
-from algoperf import param_utils
-from algoperf import spec
+from algoperf import param_utils, spec
from algoperf.workloads.ogbg import metrics
from algoperf.workloads.ogbg.ogbg_jax import models
from algoperf.workloads.ogbg.workload import BaseOgbgWorkload
class OgbgWorkload(BaseOgbgWorkload):
-
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
rng, params_rng = jax.random.split(rng, 2)
self._model = models.GNN(
- self._num_outputs,
- activation_fn_name=self.activation_fn_name,
- hidden_dims=self.hidden_dims,
- latent_dim=self.latent_dim,
- num_message_passing_steps=self.num_message_passing_steps)
+ self._num_outputs,
+ activation_fn_name=self.activation_fn_name,
+ hidden_dims=self.hidden_dims,
+ latent_dim=self.latent_dim,
+ num_message_passing_steps=self.num_message_passing_steps,
+ )
init_fn = jax.jit(functools.partial(self._model.init, train=False))
fake_batch = jraph.GraphsTuple(
- n_node=jnp.asarray([1]),
- n_edge=jnp.asarray([1]),
- nodes=jnp.ones((1, 9)),
- edges=jnp.ones((1, 3)),
- globals=jnp.zeros((1, self._num_outputs)),
- senders=jnp.asarray([0]),
- receivers=jnp.asarray([0]))
+ n_node=jnp.asarray([1]),
+ n_edge=jnp.asarray([1]),
+ nodes=jnp.ones((1, 9)),
+ edges=jnp.ones((1, 3)),
+ globals=jnp.zeros((1, self._num_outputs)),
+ senders=jnp.asarray([0]),
+ receivers=jnp.asarray([0]),
+ )
params = init_fn({'params': params_rng}, fake_batch)
params = params['params']
self._param_shapes = param_utils.jax_param_shapes(params)
@@ -44,40 +45,45 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_17'
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
"""Get predicted logits from the network for input graphs."""
del update_batch_norm # No BN in the GNN model.
if model_state is not None:
raise ValueError(
- f'Expected model_state to be None, received {model_state}.')
+ f'Expected model_state to be None, received {model_state}.'
+ )
train = mode == spec.ForwardPassMode.TRAIN
- logits = self._model.apply({'params': params},
- augmented_and_preprocessed_input_batch['inputs'],
- rngs={'dropout': rng},
- train=train,
- dropout_rate=dropout_rate)
+ logits = self._model.apply(
+ {'params': params},
+ augmented_and_preprocessed_input_batch['inputs'],
+ rngs={'dropout': rng},
+ train=train,
+ dropout_rate=dropout_rate,
+ )
return logits, None
def _binary_cross_entropy_with_mask(
- self,
- labels: jnp.ndarray,
- logits: jnp.ndarray,
- mask: jnp.ndarray,
- label_smoothing: float = 0.0) -> jnp.ndarray:
+ self,
+ labels: jnp.ndarray,
+ logits: jnp.ndarray,
+ mask: jnp.ndarray,
+ label_smoothing: float = 0.0,
+ ) -> jnp.ndarray:
"""Binary cross entropy loss for logits, with masked elements."""
if not (logits.shape == labels.shape == mask.shape): # pylint: disable=superfluous-parens
raise ValueError(
- f'Shape mismatch between logits ({logits.shape}), targets '
- f'({labels.shape}), and weights ({mask.shape}).')
+ f'Shape mismatch between logits ({logits.shape}), targets '
+ f'({labels.shape}), and weights ({mask.shape}).'
+ )
if len(logits.shape) != 2:
raise ValueError(f'Rank of logits ({logits.shape}) must be 2.')
@@ -93,26 +99,31 @@ def _binary_cross_entropy_with_mask(
positive_logits = logits >= 0
relu_logits = jnp.where(positive_logits, logits, 0)
abs_logits = jnp.where(positive_logits, logits, -logits)
- losses = relu_logits - (logits * smoothed_labels) + (
- jnp.log(1 + jnp.exp(-abs_logits)))
- return jnp.where(mask, losses, 0.)
+ losses = (
+ relu_logits
+ - (logits * smoothed_labels)
+ + (jnp.log(1 + jnp.exp(-abs_logits)))
+ )
+ return jnp.where(mask, losses, 0.0)
def _eval_metric(self, labels, logits, masks):
loss = self.loss_fn(labels, logits, masks)
return metrics.EvalMetrics.single_from_model_output(
- loss=loss['per_example'], logits=logits, labels=labels, mask=masks)
+ loss=loss['per_example'], logits=logits, labels=labels, mask=masks
+ )
@functools.partial(
- jax.pmap,
- axis_name='batch',
- in_axes=(None, 0, 0, 0, None),
- static_broadcasted_argnums=(0,))
+ jax.pmap,
+ axis_name='batch',
+ in_axes=(None, 0, 0, 0, None),
+ static_broadcasted_argnums=(0,),
+ )
def _eval_batch(self, params, batch, model_state, rng):
return super()._eval_batch(params, batch, model_state, rng)
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
del num_examples
total_metrics = total_metrics.reduce()
@@ -120,7 +131,6 @@ def _normalize_eval_metrics(
class OgbgGeluWorkload(OgbgWorkload):
-
@property
def activation_fn_name(self) -> str:
"""Name of the activation function to use. One of 'relu', 'gelu', 'silu'."""
@@ -136,7 +146,6 @@ def test_target_value(self) -> float:
class OgbgSiluWorkload(OgbgWorkload):
-
@property
def activation_fn_name(self) -> str:
"""Name of the activation function to use. One of 'relu', 'gelu', 'silu'."""
@@ -152,7 +161,6 @@ def test_target_value(self) -> float:
class OgbgModelSizeWorkload(OgbgWorkload):
-
@property
def hidden_dims(self) -> Tuple[int]:
return (256, 256)
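A quick numerical check of the stable BCE rearrangement used in _binary_cross_entropy_with_mask above: relu(z) - z*y + log(1 + exp(-|z|)) matches the textbook -[y*log(sigmoid(z)) + (1-y)*log(1-sigmoid(z))] for moderate logits and stays finite where the naive form overflows.

import jax
import jax.numpy as jnp

z = jnp.array([-3.0, 0.0, 2.0, 40.0])  # logits
y = jnp.array([0.0, 1.0, 1.0, 0.0])    # labels, no smoothing
stable = jnp.where(z >= 0, z, 0) - z * y + jnp.log(1 + jnp.exp(-jnp.abs(z)))
naive = -(y * jnp.log(jax.nn.sigmoid(z)) + (1 - y) * jnp.log(1 - jax.nn.sigmoid(z)))
# The two agree on the first three entries; at z = 40 the naive form hits
# log(1 - sigmoid(40)) = log(0) = -inf while the stable form stays at ~40.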
diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/models.py b/algoperf/workloads/ogbg/ogbg_pytorch/models.py
index 8a40bef58..a69bc6ee1 100644
--- a/algoperf/workloads/ogbg/ogbg_pytorch/models.py
+++ b/algoperf/workloads/ogbg/ogbg_pytorch/models.py
@@ -4,13 +4,12 @@
from typing import Callable, Optional, Tuple
import jax.tree_util as tree
-from jraph import GraphsTuple
import torch
+from jraph import GraphsTuple
from torch import nn
from algoperf import init_utils
-from algoperf.pytorch_utils import CustomDropout
-from algoperf.pytorch_utils import SequentialWithDropout
+from algoperf.pytorch_utils import CustomDropout, SequentialWithDropout
DROPOUT_RATE = 0.1
@@ -19,8 +18,9 @@ def _make_mlp(in_dim, hidden_dims, activation_fn):
"""Creates a MLP with specified dimensions."""
layers = SequentialWithDropout()
for i, dim in enumerate(hidden_dims):
- layers.add_module(f'dense_{i}',
- nn.Linear(in_features=in_dim, out_features=dim))
+ layers.add_module(
+ f'dense_{i}', nn.Linear(in_features=in_dim, out_features=dim)
+ )
layers.add_module(f'norm_{i}', nn.LayerNorm(dim, eps=1e-6))
layers.add_module(f'activation_fn_{i}', activation_fn())
layers.add_module(f'dropout_{i}', CustomDropout())
@@ -35,12 +35,14 @@ class GNN(nn.Module):
variables. The final prediction will be encoded in the globals.
"""
- def __init__(self,
- num_outputs: int = 128,
- activation_fn_name: str = 'relu',
- latent_dim: int = 256,
- hidden_dims: Tuple[int] = (256,),
- num_message_passing_steps: int = 5) -> None:
+ def __init__(
+ self,
+ num_outputs: int = 128,
+ activation_fn_name: str = 'relu',
+ latent_dim: int = 256,
+ hidden_dims: Tuple[int] = (256,),
+ num_message_passing_steps: int = 5,
+ ) -> None:
super().__init__()
self.latent_dim = latent_dim
self.hidden_dims = hidden_dims
@@ -58,7 +60,8 @@ def __init__(self,
activation_fn = nn.SiLU
else:
raise ValueError(
- f'Invalid activation function name: {self.activation_fn_name}')
+ f'Invalid activation function name: {self.activation_fn_name}'
+ )
graph_network_layers = []
for st in range(self.num_message_passing_steps):
@@ -66,8 +69,9 @@ def __init__(self,
# specifically update_edge_fn update_node_fn and update_global_fn.
if st == 0:
in_dim_edge_fn = self.latent_dim * 3 + self.num_outputs
- in_dim_node_fn = self.latent_dim + self.hidden_dims[
- -1] * 2 + self.num_outputs
+ in_dim_node_fn = (
+ self.latent_dim + self.hidden_dims[-1] * 2 + self.num_outputs
+ )
last_in_dim = self.hidden_dims[-1] * 2 + self.num_outputs
else:
in_dim_edge_fn = self.hidden_dims[-1] * 4
@@ -75,32 +79,36 @@ def __init__(self,
last_in_dim = self.hidden_dims[-1] * 3
graph_network_layers.append(
- GraphNetwork(
- update_edge_fn=_make_mlp(in_dim_edge_fn,
- self.hidden_dims,
- activation_fn),
- update_node_fn=_make_mlp(in_dim_node_fn,
- self.hidden_dims,
- activation_fn),
- update_global_fn=_make_mlp(last_in_dim,
- self.hidden_dims,
- activation_fn)))
+ GraphNetwork(
+ update_edge_fn=_make_mlp(
+ in_dim_edge_fn, self.hidden_dims, activation_fn
+ ),
+ update_node_fn=_make_mlp(
+ in_dim_node_fn, self.hidden_dims, activation_fn
+ ),
+ update_global_fn=_make_mlp(
+ last_in_dim, self.hidden_dims, activation_fn
+ ),
+ )
+ )
self.graph_network = SequentialWithDropout(*graph_network_layers)
self.decoder = nn.Linear(
- in_features=self.hidden_dims[-1], out_features=self.num_outputs)
+ in_features=self.hidden_dims[-1], out_features=self.num_outputs
+ )
for m in self.modules():
if isinstance(m, nn.Linear):
init_utils.pytorch_default_init(m)
- def forward(self,
- graph: GraphsTuple,
- dropout_rate: float = DROPOUT_RATE) -> torch.Tensor:
-
+ def forward(
+ self, graph: GraphsTuple, dropout_rate: float = DROPOUT_RATE
+ ) -> torch.Tensor:
graph = graph._replace(
- globals=torch.zeros([graph.n_node.shape[0], self.num_outputs],
- device=graph.n_node.device))
+ globals=torch.zeros(
+ [graph.n_node.shape[0], self.num_outputs], device=graph.n_node.device
+ )
+ )
graph = graph._replace(nodes=self.node_embedder(graph.nodes))
graph = graph._replace(edges=self.edge_embedder(graph.edges))
@@ -138,10 +146,12 @@ class GraphNetwork(nn.Module):
A method that applies the configured GraphNetwork.
"""
- def __init__(self,
- update_edge_fn: Optional[Callable] = None,
- update_node_fn: Optional[Callable] = None,
- update_global_fn: Optional[Callable] = None) -> None:
+ def __init__(
+ self,
+ update_edge_fn: Optional[Callable] = None,
+ update_node_fn: Optional[Callable] = None,
+ update_global_fn: Optional[Callable] = None,
+ ) -> None:
super().__init__()
self.update_edge_fn = update_edge_fn
self.update_node_fn = update_node_fn
@@ -168,35 +178,41 @@ def forward(self, graph: GraphsTuple, dropout_rate: float) -> GraphsTuple:
nodes, edges, receivers, senders, globals_, n_node, n_edge = graph
sum_n_node = tree.tree_leaves(nodes)[0].shape[0]
if not tree.tree_all(
- tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)):
+ tree.tree_map(lambda n: n.shape[0] == sum_n_node, nodes)
+ ):
raise ValueError(
- 'All node arrays in nest must contain the same number of nodes.')
+ 'All node arrays in nest must contain the same number of nodes.'
+ )
sent_attributes = tree.tree_map(lambda n: n[senders], nodes)
received_attributes = tree.tree_map(lambda n: n[receivers], nodes)
# Here we scatter the global features to the corresponding edges,
# giving us tensors of shape [num_edges, global_feat].
global_edge_attributes = tree.tree_map(
- lambda g: torch.repeat_interleave(g, n_edge, dim=0), globals_)
+ lambda g: torch.repeat_interleave(g, n_edge, dim=0), globals_
+ )
if self.update_edge_fn:
edge_fn_inputs = torch.cat(
- [edges, sent_attributes, received_attributes, global_edge_attributes],
- dim=-1)
+ [edges, sent_attributes, received_attributes, global_edge_attributes],
+ dim=-1,
+ )
edges = self.update_edge_fn(edge_fn_inputs, dropout_rate)
if self.update_node_fn:
sent_attributes = tree.tree_map(
- lambda e: scatter_sum(e, senders, dim=0, dim_size=sum_n_node), edges)
+ lambda e: scatter_sum(e, senders, dim=0, dim_size=sum_n_node), edges
+ )
received_attributes = tree.tree_map(
- lambda e: scatter_sum(e, receivers, dim=0, dim_size=sum_n_node),
- edges)
+ lambda e: scatter_sum(e, receivers, dim=0, dim_size=sum_n_node), edges
+ )
# Here we scatter the global features to the corresponding nodes,
# giving us tensors of shape [num_nodes, global_feat].
global_attributes = tree.tree_map(
- lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_)
+ lambda g: torch.repeat_interleave(g, n_node, dim=0), globals_
+ )
node_fn_inputs = torch.cat(
- [nodes, sent_attributes, received_attributes, global_attributes],
- dim=-1)
+ [nodes, sent_attributes, received_attributes, global_attributes], dim=-1
+ )
nodes = self.update_node_fn(node_fn_inputs, dropout_rate)
if self.update_global_fn:
@@ -210,31 +226,37 @@ def forward(self, graph: GraphsTuple, dropout_rate: float) -> GraphsTuple:
edge_gr_idx = torch.repeat_interleave(graph_idx, n_edge, dim=0)
# We use the aggregation function to pool the nodes/edges per graph.
node_attributes = tree.tree_map(
- lambda n: scatter_sum(n, node_gr_idx, dim=0, dim_size=n_graph), nodes)
+ lambda n: scatter_sum(n, node_gr_idx, dim=0, dim_size=n_graph), nodes
+ )
edge_attributes = tree.tree_map(
- lambda e: scatter_sum(e, edge_gr_idx, dim=0, dim_size=n_graph), edges)
+ lambda e: scatter_sum(e, edge_gr_idx, dim=0, dim_size=n_graph), edges
+ )
# These pooled nodes are the inputs to the global update fn.
- global_fn_inputs = torch.cat([node_attributes, edge_attributes, globals_],
- dim=-1)
+ global_fn_inputs = torch.cat(
+ [node_attributes, edge_attributes, globals_], dim=-1
+ )
globals_ = self.update_global_fn(global_fn_inputs, dropout_rate)
return GraphsTuple(
- nodes=nodes,
- edges=edges,
- receivers=receivers,
- senders=senders,
- globals=globals_,
- n_node=n_node,
- n_edge=n_edge)
+ nodes=nodes,
+ edges=edges,
+ receivers=receivers,
+ senders=senders,
+ globals=globals_,
+ n_node=n_node,
+ n_edge=n_edge,
+ )
# Forked from
# github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/scatter.py.
-def scatter_sum(src: torch.Tensor,
- index: torch.Tensor,
- dim: int = -1,
- out: Optional[torch.Tensor] = None,
- dim_size: Optional[int] = None) -> torch.Tensor:
+def scatter_sum(
+ src: torch.Tensor,
+ index: torch.Tensor,
+ dim: int = -1,
+ out: Optional[torch.Tensor] = None,
+ dim_size: Optional[int] = None,
+) -> torch.Tensor:
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
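scatter_sum above provides the segment sum the GraphNetwork uses to pool edge features into nodes (and nodes/edges into graphs). An equivalent toy example, separate from the diff, using torch's built-in index_add_, which computes the same dim=0 sum:

import torch

edges = torch.tensor([[1.0, 1.0], [2.0, 2.0], [4.0, 4.0]])  # 3 edges, 2 features each
receivers = torch.tensor([0, 2, 0])                         # edge -> receiving node
num_nodes = 3
aggregated = torch.zeros(num_nodes, 2).index_add_(0, receivers, edges)
# aggregated == [[5., 5.], [0., 0.], [2., 2.]]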
diff --git a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
index a45a93668..f72ff5141 100644
--- a/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
+++ b/algoperf/workloads/ogbg/ogbg_pytorch/workload.py
@@ -1,16 +1,15 @@
"""OGBG workload implemented in PyTorch."""
+
import contextlib
from typing import Any, Callable, Dict, Optional, Tuple
import jax
-from jraph import GraphsTuple
import torch
import torch.distributed as dist
+from jraph import GraphsTuple
from torch.nn.parallel import DistributedDataParallel as DDP
-from algoperf import param_utils
-from algoperf import pytorch_utils
-from algoperf import spec
+from algoperf import param_utils, pytorch_utils, spec
from algoperf.workloads.ogbg import metrics
from algoperf.workloads.ogbg.ogbg_pytorch import models
from algoperf.workloads.ogbg.ogbg_pytorch.models import GNN
@@ -23,9 +22,11 @@ def _pytorch_map(inputs: Any) -> Any:
if USE_PYTORCH_DDP:
return jax.tree.map(lambda a: torch.as_tensor(a, device=DEVICE), inputs)
return jax.tree.map(
- lambda a: torch.as_tensor(a, device=DEVICE).view(-1, a.shape[-1])
- if len(a.shape) == 3 else torch.as_tensor(a, device=DEVICE).view(-1),
- inputs)
+ lambda a: torch.as_tensor(a, device=DEVICE).view(-1, a.shape[-1])
+ if len(a.shape) == 3
+ else torch.as_tensor(a, device=DEVICE).view(-1),
+ inputs,
+ )
def _shard(inputs: Any) -> Any:
@@ -36,44 +37,47 @@ def _shard(inputs: Any) -> Any:
def _graph_map(function: Callable, graph: GraphsTuple) -> GraphsTuple:
return GraphsTuple(
- nodes=function(graph.nodes),
- edges=function(graph.edges),
- receivers=function(graph.receivers),
- senders=function(graph.senders),
- globals=function(graph.globals),
- n_node=function(graph.n_node),
- n_edge=function(graph.n_edge))
+ nodes=function(graph.nodes),
+ edges=function(graph.edges),
+ receivers=function(graph.receivers),
+ senders=function(graph.senders),
+ globals=function(graph.globals),
+ n_node=function(graph.n_node),
+ n_edge=function(graph.n_edge),
+ )
class OgbgWorkload(BaseOgbgWorkload):
-
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
valid examples in batch, 'per_example': 1-d array of per-example losses}
(not synced across devices).
"""
- loss_dict = super().loss_fn(label_batch,
- logits_batch,
- mask_batch,
- label_smoothing)
+ loss_dict = super().loss_fn(
+ label_batch, logits_batch, mask_batch, label_smoothing
+ )
loss_dict['n_valid_examples'] = torch.as_tensor(
- loss_dict['n_valid_examples'], device=DEVICE)
+ loss_dict['n_valid_examples'], device=DEVICE
+ )
return loss_dict
- def _build_input_queue(self,
- data_rng: jax.random.PRNGKey,
- split: str,
- data_dir: str,
- global_batch_size: int):
+ def _build_input_queue(
+ self,
+ data_rng: jax.random.PRNGKey,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ ):
# TODO: Check where the + 1 comes from.
per_device_batch_size = int(global_batch_size / N_GPUS) + 1
@@ -81,10 +85,9 @@ def _build_input_queue(self,
# avoid creating too many threads.
if RANK == 0:
data_rng = data_rng.astype('uint32')
- dataset_iter = super()._build_input_queue(data_rng,
- split,
- data_dir,
- global_batch_size)
+ dataset_iter = super()._build_input_queue(
+ data_rng, split, data_dir, global_batch_size
+ )
while True:
if RANK == 0:
@@ -92,14 +95,16 @@ def _build_input_queue(self,
graph = _graph_map(_pytorch_map, batch['inputs'])
targets = torch.as_tensor(batch['targets'], device=DEVICE)
weights = torch.as_tensor(
- batch['weights'], dtype=torch.bool, device=DEVICE)
+ batch['weights'], dtype=torch.bool, device=DEVICE
+ )
# Send batch to other devices when using DDP.
if USE_PYTORCH_DDP:
dist.broadcast_object_list([graph], src=0, device=DEVICE)
# During eval, the batch size of the remainder might be different.
if split != 'train':
per_device_batch_size = torch.tensor(
- len(targets[0]), dtype=torch.int32, device=DEVICE)
+ len(targets[0]), dtype=torch.int32, device=DEVICE
+ )
dist.broadcast(per_device_batch_size, src=0)
dist.broadcast(targets, src=0)
targets = targets[0]
@@ -114,25 +119,27 @@ def _build_input_queue(self,
graph = graph[0]
# During eval, the batch size of the remainder might be different.
if split != 'train':
- per_device_batch_size = torch.empty((1,),
- dtype=torch.int32,
- device=DEVICE)
+ per_device_batch_size = torch.empty(
+ (1,), dtype=torch.int32, device=DEVICE
+ )
dist.broadcast(per_device_batch_size, src=0)
targets = torch.empty(
- (N_GPUS, per_device_batch_size, self._num_outputs), device=DEVICE)
+ (N_GPUS, per_device_batch_size, self._num_outputs), device=DEVICE
+ )
dist.broadcast(targets, src=0)
targets = targets[RANK]
weights = torch.empty(
- (N_GPUS, per_device_batch_size, self._num_outputs),
- dtype=torch.bool,
- device=DEVICE)
+ (N_GPUS, per_device_batch_size, self._num_outputs),
+ dtype=torch.bool,
+ device=DEVICE,
+ )
dist.broadcast(weights, src=0)
weights = weights[RANK]
batch = {
- 'inputs': _graph_map(_shard, graph),
- 'targets': targets,
- 'weights': weights,
+ 'inputs': _graph_map(_shard, graph),
+ 'targets': targets,
+ 'weights': weights,
}
yield batch
@@ -140,11 +147,12 @@ def _build_input_queue(self,
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
model = GNN(
- num_outputs=self._num_outputs,
- hidden_dims=self.hidden_dims,
- latent_dim=self.latent_dim,
- num_message_passing_steps=self.num_message_passing_steps,
- activation_fn_name=self.activation_fn_name)
+ num_outputs=self._num_outputs,
+ hidden_dims=self.hidden_dims,
+ latent_dim=self.latent_dim,
+ num_message_passing_steps=self.num_message_passing_steps,
+ activation_fn_name=self.activation_fn_name,
+ )
self._param_shapes = param_utils.pytorch_param_shapes(model)
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
model.to(DEVICE)
@@ -159,21 +167,22 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['decoder.weight', 'decoder.bias']
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
"""Get predicted logits from the network for input graphs."""
del rng
del update_batch_norm # No BN in the GNN model.
if model_state is not None:
raise ValueError(
- f'Expected model_state to be None, received {model_state}.')
+ f'Expected model_state to be None, received {model_state}.'
+ )
model = params
if mode == spec.ForwardPassMode.TRAIN:
@@ -182,28 +191,31 @@ def model_fn(
model.eval()
contexts = {
- spec.ForwardPassMode.EVAL: torch.no_grad,
- spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
+ spec.ForwardPassMode.EVAL: torch.no_grad,
+ spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
logits = model(
- augmented_and_preprocessed_input_batch['inputs'],
- dropout_rate=dropout_rate)
+ augmented_and_preprocessed_input_batch['inputs'],
+ dropout_rate=dropout_rate,
+ )
return logits, None
def _binary_cross_entropy_with_mask(
- self,
- labels: torch.Tensor,
- logits: torch.Tensor,
- mask: torch.Tensor,
- label_smoothing: float = 0.0) -> torch.Tensor:
+ self,
+ labels: torch.Tensor,
+ logits: torch.Tensor,
+ mask: torch.Tensor,
+ label_smoothing: float = 0.0,
+ ) -> torch.Tensor:
"""Binary cross entropy loss for logits, with masked elements."""
if not (logits.shape == labels.shape == mask.shape): # pylint: disable=superfluous-parens
raise ValueError(
- f'Shape mismatch between logits ({logits.shape}), targets '
- f'({labels.shape}), and weights ({mask.shape}).')
+ f'Shape mismatch between logits ({logits.shape}), targets '
+ f'({labels.shape}), and weights ({mask.shape}).'
+ )
if len(logits.shape) != 2:
raise ValueError(f'Rank of logits ({logits.shape}) must be 2.')
@@ -213,36 +225,40 @@ def _binary_cross_entropy_with_mask(
# Apply label_smoothing.
num_classes = labels.shape[-1]
- smoothed_labels = ((1.0 - label_smoothing) * labels +
- label_smoothing / num_classes)
+ smoothed_labels = (
+ 1.0 - label_smoothing
+ ) * labels + label_smoothing / num_classes
# Numerically stable implementation of BCE loss.
# This mimics TensorFlow's tf.nn.sigmoid_cross_entropy_with_logits().
positive_logits = logits >= 0
relu_logits = torch.where(positive_logits, logits, 0)
abs_logits = torch.where(positive_logits, logits, -logits)
- losses = relu_logits - (logits * smoothed_labels) + (
- torch.log(1 + torch.exp(-abs_logits)))
- return torch.where(mask.to(torch.bool), losses, 0.)
+ losses = (
+ relu_logits
+ - (logits * smoothed_labels)
+ + (torch.log(1 + torch.exp(-abs_logits)))
+ )
+ return torch.where(mask.to(torch.bool), losses, 0.0)
def _eval_metric(self, labels, logits, masks):
loss = self.loss_fn(labels, logits, masks)
return metrics.EvalMetrics.single_from_model_output(
- loss=loss['per_example'].cpu().numpy(),
- logits=logits.cpu().numpy(),
- labels=labels.cpu().numpy(),
- mask=masks.cpu().numpy())
+ loss=loss['per_example'].cpu().numpy(),
+ logits=logits.cpu().numpy(),
+ labels=labels.cpu().numpy(),
+ mask=masks.cpu().numpy(),
+ )
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
del num_examples
return {k: float(v) for k, v in total_metrics.compute().items()}
class OgbgGeluWorkload(OgbgWorkload):
-
@property
def activation_fn_name(self) -> str:
"""Name of the activation function to use. One of 'relu', 'gelu', 'silu'."""
@@ -258,7 +274,6 @@ def test_target_value(self) -> float:
class OgbgSiluWorkload(OgbgWorkload):
-
@property
def activation_fn_name(self) -> str:
"""Name of the activation function to use. One of 'relu', 'gelu', 'silu'."""
@@ -274,7 +289,6 @@ def test_target_value(self) -> float:
class OgbgModelSizeWorkload(OgbgWorkload):
-
@property
def hidden_dims(self) -> Tuple[int]:
return (256, 256)
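Worked numbers for the label-smoothing line in _binary_cross_entropy_with_mask above: each target is pulled toward the uniform value label_smoothing / num_classes.

import torch

labels = torch.tensor([[1.0, 0.0]])
label_smoothing = 0.1
num_classes = labels.shape[-1]  # 2 in this toy case; 128 tasks in the OGBG workload
smoothed = (1.0 - label_smoothing) * labels + label_smoothing / num_classes
# smoothed == [[0.95, 0.05]]; with 128 tasks a 1 becomes 0.9 + 0.1/128 ≈ 0.9008
# and a 0 becomes 0.1/128 ≈ 0.0008.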
diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py
index 971e7f0f6..1d182fed5 100644
--- a/algoperf/workloads/ogbg/workload.py
+++ b/algoperf/workloads/ogbg/workload.py
@@ -14,7 +14,6 @@
class BaseOgbgWorkload(spec.Workload):
-
_num_outputs: int = 128
@property
@@ -40,8 +39,10 @@ def num_message_passing_steps(self) -> int:
return 5
def has_reached_validation_target(self, eval_result: float) -> bool:
- return eval_result[
- 'validation/mean_average_precision'] > self.validation_target_value
+ return (
+ eval_result['validation/mean_average_precision']
+ > self.validation_target_value
+ )
@property
def validation_target_value(self) -> float:
@@ -94,15 +95,16 @@ def max_allowed_runtime_sec(self) -> int:
def eval_period_time_sec(self) -> int:
return 4 * 60
- def _build_input_queue(self,
- data_rng: jax.random.PRNGKey,
- split: str,
- data_dir: str,
- global_batch_size: int):
- dataset_iter = input_pipeline.get_dataset_iter(split,
- data_rng,
- data_dir,
- global_batch_size)
+ def _build_input_queue(
+ self,
+ data_rng: jax.random.PRNGKey,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ ):
+ dataset_iter = input_pipeline.get_dataset_iter(
+ split, data_rng, data_dir, global_batch_size
+ )
if split != 'train':
# Note that this stores the entire val dataset in memory.
dataset_iter = itertools.cycle(dataset_iter)
@@ -111,11 +113,12 @@ def _build_input_queue(self,
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -123,19 +126,20 @@ def loss_fn(
(not synced across devices).
"""
per_example_losses = self._binary_cross_entropy_with_mask(
- labels=label_batch,
- logits=logits_batch,
- mask=mask_batch,
- label_smoothing=label_smoothing)
+ labels=label_batch,
+ logits=logits_batch,
+ mask=mask_batch,
+ label_smoothing=label_smoothing,
+ )
if mask_batch is not None:
n_valid_examples = mask_batch.sum()
else:
n_valid_examples = len(per_example_losses)
summed_loss = per_example_losses.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': n_valid_examples,
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': n_valid_examples,
+ 'per_example': per_example_losses,
}
@property
@@ -145,39 +149,45 @@ def step_hint(self) -> int:
@abc.abstractmethod
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
- def _eval_batch(self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState) -> metrics.EvalMetrics:
+ def _eval_batch(
+ self,
+ params: spec.ParameterContainer,
+ batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ ) -> metrics.EvalMetrics:
logits, _ = self.model_fn(
- params,
- batch,
- model_state,
- spec.ForwardPassMode.EVAL,
- rng,
- update_batch_norm=False)
+ params,
+ batch,
+ model_state,
+ spec.ForwardPassMode.EVAL,
+ rng,
+ update_batch_norm=False,
+ )
return self._eval_metric(batch['targets'], logits, batch['weights'])
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del global_step
data_rng, model_rng = prng.split(rng, 2)
if split not in self._eval_iters:
self._eval_iters[split] = self._build_input_queue(
- data_rng, split, data_dir, global_batch_size=global_batch_size)
+ data_rng, split, data_dir, global_batch_size=global_batch_size
+ )
total_metrics = None
num_eval_steps = int(math.ceil(float(num_examples) / global_batch_size))
@@ -186,8 +196,10 @@ def _eval_model_on_split(self,
batch = next(self._eval_iters[split])
batch_metrics = self._eval_batch(params, batch, model_state, model_rng)
total_metrics = (
- batch_metrics
- if total_metrics is None else total_metrics.merge(batch_metrics))
+ batch_metrics
+ if total_metrics is None
+ else total_metrics.merge(batch_metrics)
+ )
if total_metrics is None:
return {}
return self._normalize_eval_metrics(num_examples, total_metrics)
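Since loss_fn above returns a summed loss and a valid-example count rather than a mean (reduction and regularization are left to the submission), a typical submission-side reduction looks like this toy sketch; under DDP both entries would normally be all-reduced before dividing.

import jax.numpy as jnp

loss_dict = {                      # shape of what loss_fn returns, toy values
  'summed': jnp.asarray(12.0),
  'n_valid_examples': jnp.asarray(8.0),
  'per_example': jnp.full((8,), 1.5),
}
mean_loss = loss_dict['summed'] / loss_dict['n_valid_examples']  # 1.5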
diff --git a/algoperf/workloads/utils.py b/algoperf/workloads/utils.py
index 7719f91fb..920c3cf46 100644
--- a/algoperf/workloads/utils.py
+++ b/algoperf/workloads/utils.py
@@ -5,10 +5,12 @@
def print_jax_model_summary(model, fake_inputs):
"""Prints a summary of the jax module."""
tabulate_fn = nn.tabulate(
- model,
- jax.random.PRNGKey(0),
- console_kwargs={
- 'force_terminal': False, 'force_jupyter': False, 'width': 240
- },
+ model,
+ jax.random.PRNGKey(0),
+ console_kwargs={
+ 'force_terminal': False,
+ 'force_jupyter': False,
+ 'width': 240,
+ },
)
print(tabulate_fn(fake_inputs, train=False))
diff --git a/algoperf/workloads/wmt/bleu.py b/algoperf/workloads/wmt/bleu.py
index ad314a7d3..e0064ba51 100644
--- a/algoperf/workloads/wmt/bleu.py
+++ b/algoperf/workloads/wmt/bleu.py
@@ -30,11 +30,11 @@
def my_log(num):
"""
- Floors the log function
+ Floors the log function
- :param num: the number
- :return: log(num) floored to a very low number
- """
+ :param num: the number
+ :return: log(num) floored to a very low number
+ """
if num == 0.0:
return -9999999999
@@ -43,12 +43,12 @@ def my_log(num):
def tokenize_13a(line):
"""
- Tokenizes an input line using a relatively minimal tokenization that is
- however equivalent to mteval-v13a, used by WMT.
+ Tokenizes an input line using a relatively minimal tokenization that is
+ however equivalent to mteval-v13a, used by WMT.
- :param line: a segment to tokenize
- :return: the tokenized line
- """
+ :param line: a segment to tokenize
+ :return: the tokenized line
+ """
norm = line
@@ -62,14 +62,17 @@ def tokenize_13a(line):
norm = norm.replace('>', '>')
# language-dependent part (assuming Western languages):
- norm = " {} ".format(norm)
+ norm = ' {} '.format(norm)
norm = re.sub(r'([\{-\~\[-\` -\&\(-\+\:-\@\/])', ' \\1 ', norm)
- norm = re.sub(r'([^0-9])([\.,])', '\\1 \\2 ',
- norm) # tokenize period and comma unless preceded by a digit
- norm = re.sub(r'([\.,])([^0-9])', ' \\1 \\2',
- norm) # tokenize period and comma unless followed by a digit
- norm = re.sub(r'([0-9])(-)', '\\1 \\2 ',
- norm) # tokenize dash when preceded by a digit
+ norm = re.sub(
+ r'([^0-9])([\.,])', '\\1 \\2 ', norm
+ ) # tokenize period and comma unless preceded by a digit
+ norm = re.sub(
+ r'([\.,])([^0-9])', ' \\1 \\2', norm
+ ) # tokenize period and comma unless followed by a digit
+ norm = re.sub(
+ r'([0-9])(-)', '\\1 \\2 ', norm
+ ) # tokenize dash when preceded by a digit
norm = re.sub(r'\s+', ' ', norm) # one space only between words
norm = re.sub(r'^\s+', '', norm) # no leading space
norm = re.sub(r'\s+$', '', norm) # no trailing space
@@ -80,14 +83,15 @@ def tokenize_13a(line):
class UnicodeRegex:
"""Ad-hoc hack to recognize all punctuation and symbols.
- without depending on https://pypi.python.org/pypi/regex/."""
+ without depending on https://pypi.python.org/pypi/regex/."""
@staticmethod
def _property_chars(prefix):
return ''.join(
- chr(x)
- for x in range(sys.maxunicode)
- if unicodedata.category(chr(x)).startswith(prefix))
+ chr(x)
+ for x in range(sys.maxunicode)
+ if unicodedata.category(chr(x)).startswith(prefix)
+ )
punctuation = _property_chars('P')
nondigit_punct_re = re.compile(r'([^\d])([' + punctuation + r'])')
@@ -98,27 +102,27 @@ def _property_chars(prefix):
def tokenize_v14_international(string):
r"""Tokenize a string following the official BLEU implementation.
- See
- https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983
- In our case, the input string is expected to be just one line
- and no HTML entities de-escaping is needed.
- So we just tokenize on punctuation and symbols,
- except when a punctuation is preceded and followed by a digit
- (e.g. a comma/dot as a thousand/decimal separator).
-
- Note that a number (e.g., a year) followed by a dot at the end of sentence
- is NOT tokenized,
- i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g`
- does not match this case (unless we add a space after each sentence).
- However, this error is already in the original mteval-v14.pl
- and we want to be consistent with it.
- The error is not present in the non-international version,
- which uses,
- `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python).
-
- :param string: the input string
- :return: a list of tokens
- """
+ See
+ https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983
+ In our case, the input string is expected to be just one line
+ and no HTML entities de-escaping is needed.
+ So we just tokenize on punctuation and symbols,
+ except when a punctuation is preceded and followed by a digit
+ (e.g. a comma/dot as a thousand/decimal separator).
+
+ Note that a number (e.g., a year) followed by a dot at the end of sentence
+ is NOT tokenized,
+ i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g`
+ does not match this case (unless we add a space after each sentence).
+ However, this error is already in the original mteval-v14.pl
+ and we want to be consistent with it.
+ The error is not present in the non-international version,
+ which uses,
+ `$norm_text = " $norm_text "` (or `norm = " {} ".format(norm)` in Python).
+
+ :param string: the input string
+ :return: a list of tokens
+ """
string = UnicodeRegex.nondigit_punct_re.sub(r'\1 \2 ', string)
string = UnicodeRegex.punct_nondigit_re.sub(r' \1 \2', string)
string = UnicodeRegex.symbol_re.sub(r' \1 ', string)
@@ -127,94 +131,94 @@ def tokenize_v14_international(string):
def tokenize_zh(sentence):
"""MIT License
- Copyright (c) 2017 - Shujian Huang
-
- Permission is hereby granted, free of charge, to any person obtaining
- a copy of this software and associated documentation files
- (the "Software"), to deal in the Software without restriction, including
- without limitation the rights to use, copy, modify, merge, publish,
- distribute, sublicense, and/or sell copies of the Software, and to
- permit persons to whom the Software is furnished to do so, subject to the
- following conditions:
-
- The above copyright notice and this permission notice shall be included
- in all copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
- OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
- DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
- OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
- USE OR OTHER DEALINGS IN THE SOFTWARE.
-
- The tokenization of Chinese text in this script contains two steps:
- separate each Chinese characters (by utf-8 encoding);
- tokenize the non Chinese part (following the mteval script).
- Author: Shujian Huang huangsj@nju.edu.cn
-
- :param sentence: input sentence
- :return: tokenized sentence
- """
+ Copyright (c) 2017 - Shujian Huang
+
+ Permission is hereby granted, free of charge, to any person obtaining
+ a copy of this software and associated documentation files
+ (the "Software"), to deal in the Software without restriction, including
+ without limitation the rights to use, copy, modify, merge, publish,
+ distribute, sublicense, and/or sell copies of the Software, and to
+ permit persons to whom the Software is furnished to do so, subject to the
+ following conditions:
+
+ The above copyright notice and this permission notice shall be included
+ in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
+ DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
+ OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
+ USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+ The tokenization of Chinese text in this script contains two steps:
+ separate each Chinese characters (by utf-8 encoding);
+ tokenize the non Chinese part (following the mteval script).
+ Author: Shujian Huang huangsj@nju.edu.cn
+
+ :param sentence: input sentence
+ :return: tokenized sentence
+ """
def is_chinese_char(uchar):
"""
- :param uchar: input char in unicode
- :return: whether the input char is a Chinese character.
- """
- if "\u3400" <= uchar <= "\u4db5":
+ :param uchar: input char in unicode
+ :return: whether the input char is a Chinese character.
+ """
+ if '\u3400' <= uchar <= '\u4db5':
return True
- elif "\u4e00" <= uchar <= "\u9fa5":
+ elif '\u4e00' <= uchar <= '\u9fa5':
return True
- elif "\u9fa6" <= uchar <= "\u9fbb":
+ elif '\u9fa6' <= uchar <= '\u9fbb':
return True
- elif "\uf900" <= uchar <= "\ufa2d":
+ elif '\uf900' <= uchar <= '\ufa2d':
return True
- elif "\ufa30" <= uchar <= "\ufa6a":
+ elif '\ufa30' <= uchar <= '\ufa6a':
return True
- elif "\ufa70" <= uchar <= "\ufad9":
+ elif '\ufa70' <= uchar <= '\ufad9':
return True
- elif "\u20000" <= uchar <= "\u2a6d6":
+ elif '\u20000' <= uchar <= '\u2a6d6':
return True
- elif "\u2f800" <= uchar <= "\u2fa1d":
+ elif '\u2f800' <= uchar <= '\u2fa1d':
return True
- elif "\uff00" <= uchar <= "\uffef":
+ elif '\uff00' <= uchar <= '\uffef':
return True
- elif "\u2e80" <= uchar <= "\u2eff":
+ elif '\u2e80' <= uchar <= '\u2eff':
return True
- elif "\u3000" <= uchar <= "\u303f":
+ elif '\u3000' <= uchar <= '\u303f':
return True
- elif "\u31c0" <= uchar <= "\u31ef":
+ elif '\u31c0' <= uchar <= '\u31ef':
return True
- elif "\u2f00" <= uchar <= "\u2fdf":
+ elif '\u2f00' <= uchar <= '\u2fdf':
return True
- elif "\u2ff0" <= uchar <= "\u2fff":
+ elif '\u2ff0' <= uchar <= '\u2fff':
return True
- elif "\u3100" <= uchar <= "\u312f":
+ elif '\u3100' <= uchar <= '\u312f':
return True
- elif "\u31a0" <= uchar <= "\u31bf":
+ elif '\u31a0' <= uchar <= '\u31bf':
return True
- elif "\ufe10" <= uchar <= "\ufe1f":
+ elif '\ufe10' <= uchar <= '\ufe1f':
return True
- elif "\ufe30" <= uchar <= "\ufe4f":
+ elif '\ufe30' <= uchar <= '\ufe4f':
return True
- elif "\u2600" <= uchar <= "\u26ff":
+ elif '\u2600' <= uchar <= '\u26ff':
return True
- elif "\u2700" <= uchar <= "\u27bf":
+ elif '\u2700' <= uchar <= '\u27bf':
return True
- elif "\u3200" <= uchar <= "\u32ff":
+ elif '\u3200' <= uchar <= '\u32ff':
return True
- elif "\u3300" <= uchar <= "\u33ff":
+ elif '\u3300' <= uchar <= '\u33ff':
return True
return False
sentence = sentence.strip()
- sentence_in_chars = ""
+ sentence_in_chars = ''
for char in sentence:
if is_chinese_char(char):
- sentence_in_chars += " "
+ sentence_in_chars += ' '
sentence_in_chars += char
- sentence_in_chars += " "
+ sentence_in_chars += ' '
else:
sentence_in_chars += char
sentence = sentence_in_chars
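For orientation, a minimal standalone sketch of the character-separation step that tokenize_zh performs above; the range check is deliberately reduced to the main CJK Unified Ideographs block (the real is_chinese_char covers many more ranges, and the non-Chinese remainder is then run through the mteval-style tokenizer):

def _is_cjk(ch):
  # Simplified stand-in for is_chinese_char: only the main CJK Unified block.
  return '\u4e00' <= ch <= '\u9fa5'


def separate_chars(sentence):
  out = []
  for ch in sentence.strip():
    out.append(f' {ch} ' if _is_cjk(ch) else ch)
  return ''.join(out)


print(separate_chars('这是一个test句子'))
# -> ' 这  是  一  个 test 句  子 '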
@@ -245,10 +249,10 @@ def is_chinese_char(uchar):
TOKENIZERS = {
- '13a': tokenize_13a,
- 'intl': tokenize_v14_international,
- 'zh': tokenize_zh,
- 'none': lambda x: x,
+ '13a': tokenize_13a,
+ 'intl': tokenize_v14_international,
+ 'zh': tokenize_zh,
+ 'none': lambda x: x,
}
DEFAULT_TOKENIZER = '13a'
@@ -256,16 +260,16 @@ def is_chinese_char(uchar):
def extract_ngrams(line, min_order=1, max_order=NGRAM_ORDER) -> Counter:
"""Extracts all the ngrams (1 <= n <= NGRAM_ORDER) from a sequence of tokens.
- :param line: a segment containing a sequence of words
- :param max_order: collect n-grams from 1<=n<=max
- :return: a dictionary containing ngrams and counts
- """
+ :param line: a segment containing a sequence of words
+ :param max_order: collect n-grams from 1<=n<=max
+ :return: a dictionary containing ngrams and counts
+ """
ngrams = Counter()
tokens = line.split()
for n in range(min_order, max_order + 1):
for i in range(0, len(tokens) - n + 1):
- ngram = ' '.join(tokens[i:i + n])
+ ngram = ' '.join(tokens[i : i + n])
ngrams[ngram] += 1
return ngrams
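As a quick illustration of the counting above, the same loop on a toy segment (max_order cut to 2 for brevity; the tokens are made up):

from collections import Counter

tokens = 'the cat sat'.split()
ngrams = Counter()
for n in range(1, 3):  # 1- and 2-grams only
  for i in range(0, len(tokens) - n + 1):
    ngrams[' '.join(tokens[i:i + n])] += 1
print(ngrams)
# Counter({'the': 1, 'cat': 1, 'sat': 1, 'the cat': 1, 'cat sat': 1})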
@@ -293,41 +297,44 @@ def ref_stats(output, refs):
return ngrams, closest_diff, closest_len
-BLEU = namedtuple('BLE',
- 'score, counts, totals, precisions, bp, sys_len, ref_len')
+BLEU = namedtuple(
+ 'BLE', 'score, counts, totals, precisions, bp, sys_len, ref_len'
+)
-def compute_bleu(correct: List[int],
- total: List[int],
- sys_len: int,
- ref_len: int,
- smooth_method='none',
- smooth_value=SMOOTH_VALUE_DEFAULT,
- use_effective_order=False) -> BLEU:
+def compute_bleu(
+ correct: List[int],
+ total: List[int],
+ sys_len: int,
+ ref_len: int,
+ smooth_method='none',
+ smooth_value=SMOOTH_VALUE_DEFAULT,
+ use_effective_order=False,
+) -> BLEU:
"""Computes BLEU score from its sufficient statistics. Adds smoothing.
- Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques
- for Sentence-Level BLEU", Boxing Chen and Colin Cherry,
- WMT 2014: http://aclweb.org/anthology/W14-3346)
-
- - exp: NIST smoothing method (Method 3)
- - floor: Method 1
- - add-k: Method 2 (generalizing Lin and Och, 2004)
- - none: do nothing.
-
- :param correct: List of counts of correct ngrams, 1 <= n <= NGRAM_ORDER
- :param total: List of counts of total ngrams, 1 <= n <= NGRAM_ORDER
- :param sys_len: The cumulative system length
- :param ref_len: The cumulative reference length
- :param smooth: The smoothing method to use
- :param smooth_value: The smoothing value added, if smooth is 'floor'
- :param use_effective_order: Use effective order.
- :return: A BLEU object with the score (100-based) and other statistics.
- """
+ Smoothing methods (citing "A Systematic Comparison of Smoothing Techniques
+ for Sentence-Level BLEU", Boxing Chen and Colin Cherry,
+ WMT 2014: http://aclweb.org/anthology/W14-3346)
+
+ - exp: NIST smoothing method (Method 3)
+ - floor: Method 1
+ - add-k: Method 2 (generalizing Lin and Och, 2004)
+ - none: do nothing.
+
+ :param correct: List of counts of correct ngrams, 1 <= n <= NGRAM_ORDER
+ :param total: List of counts of total ngrams, 1 <= n <= NGRAM_ORDER
+ :param sys_len: The cumulative system length
+ :param ref_len: The cumulative reference length
+ :param smooth_method: The smoothing method to use
+ :param smooth_value: The smoothing value added, if smooth is 'floor'
+ :param use_effective_order: Use effective order.
+ :return: A BLEU object with the score (100-based) and other statistics.
+ """
precisions = [0 for x in range(NGRAM_ORDER)]
- smooth_mteval = 1.
+ smooth_mteval = 1.0
effective_order = NGRAM_ORDER
for n in range(NGRAM_ORDER):
if smooth_method == 'add-k' and n > 1:
@@ -342,11 +349,11 @@ def compute_bleu(correct: List[int],
if correct[n] == 0:
if smooth_method == 'exp':
smooth_mteval *= 2
- precisions[n] = 100. / (smooth_mteval * total[n])
+ precisions[n] = 100.0 / (smooth_mteval * total[n])
elif smooth_method == 'floor':
- precisions[n] = 100. * smooth_value / total[n]
+ precisions[n] = 100.0 * smooth_value / total[n]
else:
- precisions[n] = 100. * correct[n] / total[n]
+ precisions[n] = 100.0 * correct[n] / total[n]
# If the system guesses no i-grams, 1 <= i <= NGRAM_ORDER, the BLEU
# score is 0 (technically undefined). This is a problem for sentence-level
@@ -360,20 +367,24 @@ def compute_bleu(correct: List[int],
brevity_penalty = math.exp(1 - ref_len / sys_len) if sys_len > 0 else 0.0
bleu = brevity_penalty * math.exp(
- sum(map(my_log, precisions[:effective_order])) / effective_order)
+ sum(map(my_log, precisions[:effective_order])) / effective_order
+ )
return BLEU._make(
- [bleu, correct, total, precisions, brevity_penalty, sys_len, ref_len])
-
-
-def corpus_bleu(sys_stream: Sequence[str],
- ref_streams: Sequence[str],
- smooth_method: str = 'exp',
- smooth_value: float = 0.0,
- force: bool = False,
- lowercase: bool = False,
- tokenize: str = '13a',
- use_effective_order: bool = False) -> BLEU:
+ [bleu, correct, total, precisions, brevity_penalty, sys_len, ref_len]
+ )
+
+
+def corpus_bleu(
+ sys_stream: Sequence[str],
+ ref_streams: Sequence[str],
+ smooth_method: str = 'exp',
+ smooth_value: float = 0.0,
+ force: bool = False,
+ lowercase: bool = False,
+ tokenize: str = '13a',
+ use_effective_order: bool = False,
+) -> BLEU:
"""Produces BLEU scores along with its sufficient statistics from a source
against one or more references.
:param sys_stream: The system stream (a sequence of segments).
@@ -414,13 +425,16 @@ def corpus_bleu(sys_stream: Sequence[str],
tokenized_count += 1
if tokenized_count == 100:
+ logging.warning("That's 100 lines that end in a tokenized period ('.')")
+ logging.warning(
+ 'It looks like you forgot to detokenize your test '
+ 'data, which may hurt your score.'
+ )
logging.warning(
- 'That\'s 100 lines that end in a tokenized period (\'.\')')
- logging.warning('It looks like you forgot to detokenize your test '
- 'data, which may hurt your score.')
- logging.warning('If you insist your data is detokenized, '
- 'or don\'t care, you can suppress this message with '
- '\'--force\'.')
+ 'If you insist your data is detokenized, '
+ "or don't care, you can suppress this message with "
+ "'--force'."
+ )
output, *refs = [TOKENIZERS[tokenize](x.rstrip()) for x in lines]
@@ -453,10 +467,11 @@ def corpus_bleu(sys_stream: Sequence[str],
total = total.cpu().numpy().tolist()
return compute_bleu(
- correct,
- total,
- sys_len,
- ref_len,
- smooth_method=smooth_method,
- smooth_value=smooth_value,
- use_effective_order=use_effective_order)
+ correct,
+ total,
+ sys_len,
+ ref_len,
+ smooth_method=smooth_method,
+ smooth_value=smooth_value,
+ use_effective_order=use_effective_order,
+ )
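For reference, a self-contained sketch of the arithmetic compute_bleu performs on its sufficient statistics, using made-up counts and the 'exp' (NIST) smoothing branch; it is a simplified stand-in, not the patched function itself:

import math

# Hypothetical statistics for 4-gram BLEU: matched vs. candidate n-gram counts.
correct = [9, 5, 3, 0]
total = [12, 11, 10, 9]
sys_len, ref_len = 12, 13

smooth_mteval = 1.0
precisions = []
for c, t in zip(correct, total):
  if c == 0:
    smooth_mteval *= 2  # 'exp' smoothing: halve the precision at each zero count
    precisions.append(100.0 / (smooth_mteval * t))
  else:
    precisions.append(100.0 * c / t)

# Brevity penalty only bites when the system output is shorter than the reference.
bp = math.exp(1 - ref_len / sys_len) if sys_len < ref_len else 1.0
bleu = bp * math.exp(sum(math.log(p) for p in precisions) / len(precisions))
print(f'BLEU = {bleu:.2f}')  # roughly 25 for these toy counts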
diff --git a/algoperf/workloads/wmt/input_pipeline.py b/algoperf/workloads/wmt/input_pipeline.py
index d743b43b0..1df1dfc55 100644
--- a/algoperf/workloads/wmt/input_pipeline.py
+++ b/algoperf/workloads/wmt/input_pipeline.py
@@ -1,4 +1,5 @@
"""Input pipeline for a WMT dataset."""
+
import functools
import os
from typing import Dict, List, Optional, Union
@@ -16,10 +17,10 @@
Features = Dict[str, tf.Tensor]
TFDS_SPLIT_NAME = {
- 'train': 'train',
- 'eval_train': 'train',
- 'validation': 'validation',
- 'test': 'test',
+ 'train': 'train',
+ 'eval_train': 'train',
+ 'validation': 'validation',
+ 'test': 'test',
}
@@ -31,9 +32,11 @@ def normalize_feature_names(ds_info, features: Features) -> Features:
return features
-def pack_dataset(dataset: tf.data.Dataset,
- key2length: Union[int, Dict[str, int]],
- keys: Optional[List[str]] = None) -> tf.data.Dataset:
+def pack_dataset(
+ dataset: tf.data.Dataset,
+ key2length: Union[int, Dict[str, int]],
+ keys: Optional[List[str]] = None,
+) -> tf.data.Dataset:
"""Creates a 'packed' version of a dataset on-the-fly.
Adapted from the mesh-tf implementation.
This is meant to replace the irritation of having to create a separate
@@ -75,7 +78,8 @@ def pack_dataset(dataset: tf.data.Dataset,
for k in keys:
if k not in shapes:
raise ValueError(
- f'Key {k} not found in dataset. Available keys are {shapes.keys()}')
+ f'Key {k} not found in dataset. Available keys are {shapes.keys()}'
+ )
if not shapes[k].is_compatible_with(tf.TensorShape([None])):
raise ValueError('Tensors to be packed must be one-dimensional.')
# make sure that the length dictionary contains all keys as well as the
@@ -88,13 +92,15 @@ def pack_dataset(dataset: tf.data.Dataset,
# trim to length
dataset = dataset.map(
- lambda x: {k: x[k][:key2length[k]] for k in keys},
- num_parallel_calls=AUTOTUNE)
+ lambda x: {k: x[k][: key2length[k]] for k in keys},
+ num_parallel_calls=AUTOTUNE,
+ )
# Setting batch_size=length ensures that the concatenated sequences (if they
# have length >=1) are sufficient to fill at least one packed example.
batch_size = max(key2length.values())
dataset = dataset.padded_batch(
- batch_size, padded_shapes={k: [-1] for k in keys})
+ batch_size, padded_shapes={k: [-1] for k in keys}
+ )
dataset = _pack_with_tf_ops(dataset, keys, key2length)
# Set the Tensor shapes correctly since they get lost in the process.
@@ -104,9 +110,9 @@ def my_fn(x):
return dataset.map(my_fn, num_parallel_calls=AUTOTUNE)
-def _pack_with_tf_ops(dataset: tf.data.Dataset,
- keys: List[str],
- key2length: Dict[str, int]) -> tf.data.Dataset:
+def _pack_with_tf_ops(
+ dataset: tf.data.Dataset, keys: List[str], key2length: Dict[str, int]
+) -> tf.data.Dataset:
"""Helper-function for packing a dataset which has already been batched.
Helper for pack_dataset() Uses tf.while_loop.
Args:
@@ -127,8 +133,9 @@ def write_packed_example(partial, outputs):
new_outputs = {}
for k in keys_etc:
new_outputs[k] = outputs[k].write(
- outputs[k].size(),
- tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]))
+ outputs[k].size(),
+ tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]),
+ )
return new_partial, new_outputs
def map_fn(x):
@@ -146,9 +153,11 @@ def map_fn(x):
outputs = {}
for k in keys:
outputs[k] = tf.TensorArray(
- tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]])
+ tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]
+ )
outputs[k + '_position'] = tf.TensorArray(
- tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]])
+ tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]
+ )
def body_fn(i, partial, outputs):
"""Body function for while_loop.
@@ -163,13 +172,15 @@ def body_fn(i, partial, outputs):
one_example = {}
for k in keys:
val = tf.cast(x[k][i], tf.int32)
- val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
+ val = val[: tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
one_example[k] = val
for k in keys:
can_append = tf.logical_and(
- can_append,
- tf.less_equal(
- tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]))
+ can_append,
+ tf.less_equal(
+ tf.size(partial[k]) + tf.size(one_example[k]), key2length[k]
+ ),
+ )
def false_fn():
return write_packed_example(partial, outputs)
@@ -180,49 +191,51 @@ def true_fn():
partial, outputs = tf.cond(can_append, true_fn, false_fn)
new_partial = {}
for k in keys:
- new_seq = one_example[k][:key2length[k]]
+ new_seq = one_example[k][: key2length[k]]
new_seq_len = tf.size(new_seq)
new_partial[k] = tf.concat([partial[k], new_seq], 0)
new_partial[k + '_position'] = tf.concat(
- [partial[k + '_position'], tf.range(new_seq_len)], 0)
+ [partial[k + '_position'], tf.range(new_seq_len)], 0
+ )
partial = new_partial
return i + 1, partial, outputs
# For loop over all examples in the batch.
i, partial, outputs = tf.while_loop(
- cond=lambda *_: True,
- body=body_fn,
- loop_vars=(i, partial, outputs),
- shape_invariants=(
- tf.TensorShape([]),
- {k: tf.TensorShape([None]) for k in keys_etc},
- {k: tf.TensorShape(None) for k in keys_etc},
- ),
- maximum_iterations=dynamic_batch_size)
+ cond=lambda *_: True,
+ body=body_fn,
+ loop_vars=(i, partial, outputs),
+ shape_invariants=(
+ tf.TensorShape([]),
+ {k: tf.TensorShape([None]) for k in keys_etc},
+ {k: tf.TensorShape(None) for k in keys_etc},
+ ),
+ maximum_iterations=dynamic_batch_size,
+ )
_, outputs = write_packed_example(partial, outputs)
packed = {k: outputs[k].stack() for k in keys_etc}
for k in keys:
- packed[k + '_segmentation'] = (
- tf.cumsum(
- tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1) *
- tf.cast(tf.not_equal(packed[k], 0), tf.int32))
+ packed[k + '_segmentation'] = tf.cumsum(
+ tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1
+ ) * tf.cast(tf.not_equal(packed[k], 0), tf.int32)
return packed
dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE)
return dataset.unbatch()
-def preprocess_wmt_data(dataset: tf.data.Dataset,
- data_rng,
- train: bool,
- shuffle: bool,
- shuffle_buffer_size: int = 1024,
- max_length: int = 256,
- global_batch_size: int = 128):
+def preprocess_wmt_data(
+ dataset: tf.data.Dataset,
+ data_rng,
+ train: bool,
+ shuffle: bool,
+ shuffle_buffer_size: int = 1024,
+ max_length: int = 256,
+ global_batch_size: int = 128,
+):
"""Shuffle and batch/pack the given dataset."""
def length_filter(max_len):
-
def filter_fn(x):
source, target = x['inputs'], x['targets']
l = tf.maximum(tf.shape(source)[0], tf.shape(target)[0])
@@ -242,24 +255,27 @@ def filter_fn(x):
dataset = dataset.batch(global_batch_size, drop_remainder=train)
else: # simple (static-shape) padded batching
dataset = dataset.padded_batch(
- global_batch_size,
- padded_shapes={'inputs': max_length, 'targets': max_length},
- padding_values={'inputs': 0, 'targets': 0},
- drop_remainder=False)
+ global_batch_size,
+ padded_shapes={'inputs': max_length, 'targets': max_length},
+ padding_values={'inputs': 0, 'targets': 0},
+ drop_remainder=False,
+ )
dataset = dataset.prefetch(AUTOTUNE)
return dataset
-def get_wmt_dataset(data_rng,
- split: str,
- data_dir: str,
- is_training: bool,
- vocab_size: int,
- global_batch_size: int,
- num_batches: Optional[int] = None,
- repeat_final_dataset: bool = False,
- vocab_path: Optional[str] = None):
+def get_wmt_dataset(
+ data_rng,
+ split: str,
+ data_dir: str,
+ is_training: bool,
+ vocab_size: int,
+ global_batch_size: int,
+ num_batches: Optional[int] = None,
+ repeat_final_dataset: bool = False,
+ vocab_path: Optional[str] = None,
+):
"""Load and return dataset of batched examples for use during training."""
if vocab_path is None:
vocab_path = os.path.join(data_dir, 'wmt_sentencepiece_model')
@@ -271,7 +287,8 @@ def get_wmt_dataset(data_rng,
dataset_builder = tfds.builder(ds_name, data_dir=data_dir)
ds = dataset_builder.as_dataset(
- split=TFDS_SPLIT_NAME[split], shuffle_files=False)
+ split=TFDS_SPLIT_NAME[split], shuffle_files=False
+ )
# Avoid creating too many threads when using PyTorch DDP.
if RANK != 0:
@@ -280,8 +297,9 @@ def get_wmt_dataset(data_rng,
ds = ds.with_options(options)
ds = ds.map(
- functools.partial(normalize_feature_names, dataset_builder.info),
- num_parallel_calls=AUTOTUNE)
+ functools.partial(normalize_feature_names, dataset_builder.info),
+ num_parallel_calls=AUTOTUNE,
+ )
# Load tf-text SentencePiece tokenizer.
sp_tokenizer = tokenizer.load_tokenizer(vocab_path=vocab_path)
@@ -289,12 +307,13 @@ def get_wmt_dataset(data_rng,
shuffle = split in ['train', 'eval_train']
ds = preprocess_wmt_data(
- ds,
- data_rng,
- train=is_training,
- shuffle=shuffle,
- global_batch_size=global_batch_size,
- max_length=256)
+ ds,
+ data_rng,
+ train=is_training,
+ shuffle=shuffle,
+ global_batch_size=global_batch_size,
+ max_length=256,
+ )
if num_batches:
ds = ds.take(num_batches)
@@ -303,9 +322,10 @@ def get_wmt_dataset(data_rng,
ds = ds.repeat()
ds = map(
- functools.partial(
- data_utils.shard_and_maybe_pad_np,
- global_batch_size=global_batch_size),
- ds)
+ functools.partial(
+ data_utils.shard_and_maybe_pad_np, global_batch_size=global_batch_size
+ ),
+ ds,
+ )
return ds, sp_tokenizer
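A small numpy sketch of the packed-example layout that pack_dataset / _pack_with_tf_ops produce; the real pipeline builds the same three features with a tf.while_loop over batched tf.data examples, and the token values here are invented:

import numpy as np

seq_a, seq_b = [7, 8, 9], [4, 5]  # two hypothetical tokenized examples
length = 8                        # packed row length (key2length for this key)

inputs = np.zeros(length, np.int32)
positions = np.zeros(length, np.int32)
segmentation = np.zeros(length, np.int32)

offset = 0
for seg_id, seq in enumerate([seq_a, seq_b], start=1):
  n = len(seq)
  inputs[offset:offset + n] = seq
  positions[offset:offset + n] = np.arange(n)  # position within the original example
  segmentation[offset:offset + n] = seg_id     # which original example the token came from
  offset += n

print(inputs)        # [7 8 9 4 5 0 0 0]
print(positions)     # [0 1 2 0 1 0 0 0]
print(segmentation)  # [1 1 1 2 2 0 0 0]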
diff --git a/algoperf/workloads/wmt/tokenizer.py b/algoperf/workloads/wmt/tokenizer.py
index 1f001e619..d94cab808 100644
--- a/algoperf/workloads/wmt/tokenizer.py
+++ b/algoperf/workloads/wmt/tokenizer.py
@@ -19,9 +19,9 @@
def _dump_chars_to_textfile(
- dataset: tf.data.Dataset,
- maxchars: int = int(1e7),
- data_keys=('inputs', 'targets')
+ dataset: tf.data.Dataset,
+ maxchars: int = int(1e7),
+ data_keys=('inputs', 'targets'),
) -> Tuple[str, int]:
"""Write part of a TFDS sentence dataset to lines in a text file.
@@ -36,7 +36,8 @@ def _dump_chars_to_textfile(
char_count = 0
ds_iter = dataset.as_numpy_iterator()
with tempfile.NamedTemporaryFile(
- delete=False, prefix='/tmp/ds_chars') as outfp:
+ delete=False, prefix='/tmp/ds_chars'
+ ) as outfp:
while char_count < maxchars:
example = next(ds_iter)
for k in data_keys:
@@ -46,14 +47,16 @@ def _dump_chars_to_textfile(
return outfp.name, char_count
-def _train_sentencepiece(dataset: tf.data.Dataset,
- *,
- vocab_size: int,
- maxchars: int = int(1e7),
- model_path: str,
- model_type: str = 'unigram',
- character_coverage: float = 1.0,
- data_keys=('inputs', 'targets')):
+def _train_sentencepiece(
+ dataset: tf.data.Dataset,
+ *,
+ vocab_size: int,
+ maxchars: int = int(1e7),
+ model_path: str,
+ model_type: str = 'unigram',
+ character_coverage: float = 1.0,
+ data_keys=('inputs', 'targets'),
+):
"""Train SentencePiece tokenizer from subset of tf dataset.
Args:
@@ -75,17 +78,21 @@ def _train_sentencepiece(dataset: tf.data.Dataset,
else:
abs_model_path = os.path.abspath(os.path.expanduser(model_path))
fname, _ = _dump_chars_to_textfile(
- dataset, maxchars=maxchars, data_keys=data_keys)
+ dataset, maxchars=maxchars, data_keys=data_keys
+ )
with tempfile.NamedTemporaryFile(
- delete=False, prefix='/tmp/sp_tmp') as model_fp:
+ delete=False, prefix='/tmp/sp_tmp'
+ ) as model_fp:
pass # we just want a prefix'd tmp-filename
- argstr = ' '.join([
+ argstr = ' '.join(
+ [
f'--input={fname}',
f'--vocab_size={vocab_size}',
f'--character_coverage={character_coverage}',
f'--model_prefix={model_fp.name}',
f'--model_type={model_type}',
- ])
+ ]
+ )
SentencePieceTrainer.Train(argstr)
if jax.process_index() == 0:
# Use an intermediate filename that is renamed to the target name to address
@@ -100,32 +107,38 @@ def _train_sentencepiece(dataset: tf.data.Dataset,
time.sleep(1)
-def _load_sentencepiece_tokenizer(model_path: str,
- add_bos: bool = False,
- add_eos: bool = True,
- reverse: bool = False):
+def _load_sentencepiece_tokenizer(
+ model_path: str,
+ add_bos: bool = False,
+ add_eos: bool = True,
+ reverse: bool = False,
+):
"""Load a tf-text SentencePiece tokenizer from given model filepath."""
with tf.io.gfile.GFile(model_path, 'rb') as model_fp:
sp_model = model_fp.read()
sp_tokenizer = tftxt.SentencepieceTokenizer(
- model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse)
+ model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse
+ )
return sp_tokenizer
-def train_tokenizer(dataset: tf.data.Dataset,
- *,
- vocab_path: str,
- vocab_size: int,
- max_corpus_chars: int,
- data_keys: Tuple[str, str] = ('inputs', 'targets')):
+def train_tokenizer(
+ dataset: tf.data.Dataset,
+ *,
+ vocab_path: str,
+ vocab_size: int,
+ max_corpus_chars: int,
+ data_keys: Tuple[str, str] = ('inputs', 'targets'),
+):
"""Trains a tokenizer from `dataset`."""
logging.info('Building SentencePiece vocab from data.')
_train_sentencepiece(
- dataset,
- vocab_size=vocab_size,
- maxchars=max_corpus_chars,
- model_path=vocab_path,
- data_keys=data_keys)
+ dataset,
+ vocab_size=vocab_size,
+ maxchars=max_corpus_chars,
+ model_path=vocab_path,
+ data_keys=data_keys,
+ )
def load_tokenizer(vocab_path: str):
@@ -135,7 +148,6 @@ def load_tokenizer(vocab_path: str):
@dataclasses.dataclass
class TokenizeOp:
-
sp_tokenizer: Any
data_keys: Iterable[str] = ('inputs', 'targets')
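A minimal usage sketch for the tokenizer handled above, assuming a trained model already exists at a hypothetical path (the workload writes it to <data_dir>/wmt_sentencepiece_model):

import tensorflow as tf
import tensorflow_text as tftxt

vocab_path = '/tmp/wmt_sentencepiece_model'  # hypothetical location

with tf.io.gfile.GFile(vocab_path, 'rb') as f:
  sp_model = f.read()

# Same settings as _load_sentencepiece_tokenizer: append EOS, no BOS, no reversal.
sp_tokenizer = tftxt.SentencepieceTokenizer(
  model=sp_model, add_bos=False, add_eos=True, reverse=False
)
ids = sp_tokenizer.tokenize('ein kleiner Testsatz')
print(ids.numpy())
print(sp_tokenizer.detokenize(ids).numpy())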
diff --git a/algoperf/workloads/wmt/wmt_jax/decode.py b/algoperf/workloads/wmt/wmt_jax/decode.py
index dfead5918..b5f5f1099 100644
--- a/algoperf/workloads/wmt/wmt_jax/decode.py
+++ b/algoperf/workloads/wmt/wmt_jax/decode.py
@@ -78,8 +78,9 @@ def gather_beams(nested, beam_indices, batch_size, new_beam_size):
[batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
"""
batch_indices = jnp.reshape(
- jnp.arange(batch_size * new_beam_size) // new_beam_size,
- (batch_size, new_beam_size))
+ jnp.arange(batch_size * new_beam_size) // new_beam_size,
+ (batch_size, new_beam_size),
+ )
def gather_fn(x):
if x.ndim < 2: # ignore scalars (e.g. cache index)
@@ -114,6 +115,7 @@ def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size):
@flax.struct.dataclass
class BeamState:
"""Holds beam search state data."""
+
# The position of the decoding loop in the length dimension.
cur_index: jax.Array # scalar int32: current decoded length index
# The active sequence log probabilities and finished sequence scores.
@@ -133,7 +135,8 @@ def beam_init(batch_size, beam_size, max_decode_len, cache):
"""Initializes the beam search state data structure."""
cur_index0 = jnp.array(0)
live_logprobs0 = jnp.tile(
- jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1])
+ jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1]
+ )
finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF
live_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32)
@@ -141,25 +144,28 @@ def beam_init(batch_size, beam_size, max_decode_len, cache):
# add beam dimension to attention cache pytree elements
beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache)
return BeamState(
- cur_index=cur_index0,
- live_logprobs=live_logprobs0,
- finished_scores=finished_scores0,
- live_seqs=live_seqs0,
- finished_seqs=finished_seqs0,
- finished_flags=finished_flags0,
- cache=beam_cache0)
+ cur_index=cur_index0,
+ live_logprobs=live_logprobs0,
+ finished_scores=finished_scores0,
+ live_seqs=live_seqs0,
+ finished_seqs=finished_seqs0,
+ finished_flags=finished_flags0,
+ cache=beam_cache0,
+ )
# Beam search routine:
-def beam_search(inputs,
- cache,
- tokens_to_logits,
- beam_size=4,
- alpha=0.6,
- eos_id=EOS_ID,
- max_decode_len=None):
+def beam_search(
+ inputs,
+ cache,
+ tokens_to_logits,
+ beam_size=4,
+ alpha=0.6,
+ eos_id=EOS_ID,
+ max_decode_len=None,
+):
"""Beam search for transformer machine translation.
Args:
@@ -185,10 +191,9 @@ def beam_search(inputs,
end_marker = jnp.array(eos_id)
# initialize beam search state
- beam_search_init_state = beam_init(batch_size,
- beam_size,
- max_decode_len,
- cache)
+ beam_search_init_state = beam_init(
+ batch_size, beam_size, max_decode_len, cache
+ )
def beam_search_loop_cond_fn(state):
"""Beam search loop termination condition."""
@@ -201,11 +206,12 @@ def beam_search_loop_cond_fn(state):
best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty
# Get the worst scores from finished sequences.
worst_finished_scores = jnp.min(
- state.finished_scores, axis=1, keepdims=True)
+ state.finished_scores, axis=1, keepdims=True
+ )
# Mask out scores from slots without any actual finished sequences.
- worst_finished_scores = jnp.where(state.finished_flags,
- worst_finished_scores,
- NEG_INF)
+ worst_finished_scores = jnp.where(
+ state.finished_flags, worst_finished_scores, NEG_INF
+ )
# If no best possible live score is better than current worst finished
# scores, the search cannot improve the finished set further.
search_terminated = jnp.all(worst_finished_scores > best_live_scores)
@@ -221,8 +227,10 @@ def beam_search_loop_body_fn(state):
# dimension for feeding into the model.
# --> [batch * beam, 1]
flat_ids = flatten_beam_dim(
- lax.dynamic_slice(state.live_seqs, (0, 0, state.cur_index),
- (batch_size, beam_size, 1)))
+ lax.dynamic_slice(
+ state.live_seqs, (0, 0, state.cur_index), (batch_size, beam_size, 1)
+ )
+ )
# Flatten beam dimension into batch to be compatible with model.
# {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
flat_cache = jax.tree.map(flatten_beam_dim, state.cache)
@@ -237,14 +245,16 @@ def beam_search_loop_body_fn(state):
# Unflatten beam dimension in attention cache arrays
# {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
new_cache = jax.tree.map(
- lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache)
+ lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache
+ )
# Gather log probabilities from logits
candidate_log_probs = jax.nn.log_softmax(logits)
# Add new logprobs to existing prefix logprobs.
# --> [batch, beam, vocab]
- log_probs = (
- candidate_log_probs + jnp.expand_dims(state.live_logprobs, axis=2))
+ log_probs = candidate_log_probs + jnp.expand_dims(
+ state.live_logprobs, axis=2
+ )
# We'll need the vocab size, gather it from the log probability dimension.
vocab_size = log_probs.shape[2]
@@ -264,10 +274,9 @@ def beam_search_loop_body_fn(state):
topk_beam_indices = topk_indices // vocab_size
# Gather 2*k top beams.
# --> [batch, 2*beams, length]
- topk_seq = gather_beams(state.live_seqs,
- topk_beam_indices,
- batch_size,
- beams_to_keep)
+ topk_seq = gather_beams(
+ state.live_seqs, topk_beam_indices, batch_size, beams_to_keep
+ )
# Append the most probable 2*K token IDs to the top 2*K sequences
# Recover token id by modulo division and expand Id array for broadcasting.
@@ -275,13 +284,14 @@ def beam_search_loop_body_fn(state):
topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
# Update sequences for the 2*K top-k new sequences.
# --> [batch, 2*beams, length]
- topk_seq = lax.dynamic_update_slice(topk_seq,
- topk_ids, (0, 0, state.cur_index + 1))
+ topk_seq = lax.dynamic_update_slice(
+ topk_seq, topk_ids, (0, 0, state.cur_index + 1)
+ )
# Update LIVE (in-progress) sequences:
# Did any of these sequences reach an end marker?
# --> [batch, 2*beams]
- newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
+ newly_finished = topk_seq[:, :, state.cur_index + 1] == end_marker
# To prevent these newly finished sequences from being added to the LIVE
# set of active beam search sequences, set their log probs to a very large
# negative value.
@@ -292,22 +302,20 @@ def beam_search_loop_body_fn(state):
new_topk_indices = jnp.flip(new_topk_indices, axis=1)
# Gather the top k beams (from top 2*k beams).
# --> [batch, beams, length], [batch, beams]
- top_alive_seq, top_alive_log_probs = gather_beams([topk_seq, new_log_probs],
- new_topk_indices,
- batch_size, beam_size)
+ top_alive_seq, top_alive_log_probs = gather_beams(
+ [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size
+ )
# Determine the top k beam indices from the original set of all beams.
# --> [batch, beams]
- top_alive_indices = gather_beams(topk_beam_indices,
- new_topk_indices,
- batch_size,
- beam_size)
+ top_alive_indices = gather_beams(
+ topk_beam_indices, new_topk_indices, batch_size, beam_size
+ )
# With these, gather the top k beam-associated caches.
# --> {[batch, beams, ...], ...}
- top_alive_cache = gather_beams(new_cache,
- top_alive_indices,
- batch_size,
- beam_size)
+ top_alive_cache = gather_beams(
+ new_cache, top_alive_indices, batch_size, beam_size
+ )
# Update FINISHED (reached end of sentence) sequences:
# Calculate new seq scores from log probabilities.
@@ -320,42 +328,54 @@ def beam_search_loop_body_fn(state):
# new finished sequence scores to existing finished scores and select the
# best from the new set of beams.
finished_seqs = jnp.concatenate( # --> [batch, 3*beams, length]
- [state.finished_seqs, topk_seq],
- axis=1)
+ [state.finished_seqs, topk_seq], axis=1
+ )
finished_scores = jnp.concatenate( # --> [batch, 3*beams]
- [state.finished_scores, new_scores], axis=1)
+ [state.finished_scores, new_scores], axis=1
+ )
finished_flags = jnp.concatenate( # --> [batch, 3*beams]
- [state.finished_flags, newly_finished], axis=1)
+ [state.finished_flags, newly_finished], axis=1
+ )
# --> [batch, beams, length], [batch, beams], [batch, beams]
top_finished_seq, top_finished_scores, top_finished_flags = (
- gather_topk_beams([finished_seqs, finished_scores, finished_flags],
- finished_scores, batch_size, beam_size))
+ gather_topk_beams(
+ [finished_seqs, finished_scores, finished_flags],
+ finished_scores,
+ batch_size,
+ beam_size,
+ )
+ )
return BeamState(
- cur_index=state.cur_index + 1,
- live_logprobs=top_alive_log_probs,
- finished_scores=top_finished_scores,
- live_seqs=top_alive_seq,
- finished_seqs=top_finished_seq,
- finished_flags=top_finished_flags,
- cache=top_alive_cache)
+ cur_index=state.cur_index + 1,
+ live_logprobs=top_alive_log_probs,
+ finished_scores=top_finished_scores,
+ live_seqs=top_alive_seq,
+ finished_seqs=top_finished_seq,
+ finished_flags=top_finished_flags,
+ cache=top_alive_cache,
+ )
# Run while loop and get final beam search state.
- final_state = lax.while_loop(beam_search_loop_cond_fn,
- beam_search_loop_body_fn,
- beam_search_init_state)
+ final_state = lax.while_loop(
+ beam_search_loop_cond_fn, beam_search_loop_body_fn, beam_search_init_state
+ )
# Account for the edge-case where there are no finished sequences for a
# particular batch item. If so, return live sequences for that batch item.
# --> [batch]
none_finished = jnp.any(final_state.finished_flags, axis=1)
# --> [batch, beams, length]
- finished_seqs = jnp.where(none_finished[:, None, None],
- final_state.finished_seqs,
- final_state.live_seqs)
+ finished_seqs = jnp.where(
+ none_finished[:, None, None],
+ final_state.finished_seqs,
+ final_state.live_seqs,
+ )
# --> [batch, beams]
- finished_scores = jnp.where(none_finished[:, None],
- final_state.finished_scores,
- final_state.live_logprobs)
+ finished_scores = jnp.where(
+ none_finished[:, None],
+ final_state.finished_scores,
+ final_state.live_logprobs,
+ )
return finished_seqs, finished_scores
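To make the per-step candidate selection in beam_search_loop_body_fn concrete, a toy JAX sketch of combining prefix log-probabilities with one step of model log-probabilities and keeping the 2*beam best continuations (shapes and numbers are invented):

import jax
import jax.numpy as jnp

beam_size, vocab_size = 2, 5
live_logprobs = jnp.array([[0.0, -1.0]])                # [batch, beam]
step_probs = jnp.array([[[0.6, 0.2, 0.1, 0.05, 0.05],
                         [0.3, 0.3, 0.2, 0.1, 0.1]]])   # [batch, beam, vocab]

# Add new log-probs to the prefix scores, flatten beam x vocab, take the top 2*k.
log_probs = jnp.log(step_probs) + live_logprobs[:, :, None]
flat = log_probs.reshape(log_probs.shape[0], -1)
topk_logprobs, topk_indices = jax.lax.top_k(flat, k=2 * beam_size)
topk_beam_indices = topk_indices // vocab_size  # which beam each candidate extends
topk_token_ids = topk_indices % vocab_size      # which token it appends
print(topk_logprobs)
print(topk_beam_indices, topk_token_ids)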
diff --git a/algoperf/workloads/wmt/wmt_jax/models.py b/algoperf/workloads/wmt/wmt_jax/models.py
index 1147eb34b..81f2ece4c 100644
--- a/algoperf/workloads/wmt/wmt_jax/models.py
+++ b/algoperf/workloads/wmt/wmt_jax/models.py
@@ -5,11 +5,11 @@
from typing import Any, Callable, Optional
+import jax.numpy as jnp
+import numpy as np
from flax import linen as nn
from flax import struct
from jax import lax
-import jax.numpy as jnp
-import numpy as np
from algoperf.jax_utils import Dropout
@@ -19,6 +19,7 @@
@struct.dataclass
class TransformerConfig:
"""Global hyperparameters used to minimize obnoxious kwarg plumbing."""
+
share_embeddings: bool = True
dtype: Any = jnp.float32
vocab_size: int = 32000
@@ -44,7 +45,8 @@ def shift_right(x, axis=1):
pad_widths = [(0, 0)] * len(x.shape)
pad_widths[axis] = (1, 0)
padded = jnp.pad(
- x, pad_widths, mode='constant', constant_values=x.dtype.type(0))
+ x, pad_widths, mode='constant', constant_values=x.dtype.type(0)
+ )
return padded[:, :-1]
@@ -68,8 +70,8 @@ def init(key, shape, dtype=np.float32):
position = np.arange(0, max_len)[:, np.newaxis]
scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1)
div_term = min_scale * np.exp(np.arange(0, d_feature // 2) * scale_factor)
- pe[:, :d_feature // 2] = np.sin(position * div_term)
- pe[:, d_feature // 2:2 * (d_feature // 2)] = np.cos(position * div_term)
+ pe[:, : d_feature // 2] = np.sin(position * div_term)
+ pe[:, d_feature // 2 : 2 * (d_feature // 2)] = np.cos(position * div_term)
pe = pe[np.newaxis, :, :] # [1, max_len, d_feature]
return jnp.array(pe)
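The same sinusoidal table that init builds above, written out as a standalone numpy sketch with toy sizes:

import numpy as np

max_len, d_feature = 6, 8  # toy sizes; the model uses cfg.max_len and the embedding dim
min_scale, max_scale = 1.0, 10000.0

pe = np.zeros((max_len, d_feature), dtype=np.float32)
position = np.arange(0, max_len)[:, np.newaxis]
scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1)
div_term = min_scale * np.exp(np.arange(0, d_feature // 2) * scale_factor)
pe[:, : d_feature // 2] = np.sin(position * div_term)
pe[:, d_feature // 2 : 2 * (d_feature // 2)] = np.cos(position * div_term)
pe = pe[np.newaxis, :, :]  # [1, max_len, d_feature], as returned by init
print(pe.shape)            # (1, 6, 8)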
@@ -83,6 +85,7 @@ class AddPositionEmbs(nn.Module):
config: TransformerConfig dataclass containing hyperparameters.
decode: whether to run in single-position autoregressive mode.
"""
+
config: TransformerConfig
decode: bool = False
@@ -103,27 +106,28 @@ def __call__(self, inputs, inputs_positions=None):
"""
cfg = self.config
# inputs.shape is (batch_size, seq_len, emb_dim)
- assert inputs.ndim == 3, ('Number of dimensions should be 3,'
- f' but it is: {inputs.ndim}')
+ assert inputs.ndim == 3, (
+ f'Number of dimensions should be 3, but it is: {inputs.ndim}'
+ )
length = inputs.shape[1]
pos_emb_shape = (1, cfg.max_len, inputs.shape[-1])
if cfg.posemb_init is None:
# Use a fixed (non-learned) sinusoidal position embedding.
- pos_embedding = sinusoidal_init(max_len=cfg.max_len)(None,
- pos_emb_shape,
- None)
+ pos_embedding = sinusoidal_init(max_len=cfg.max_len)(
+ None, pos_emb_shape, None
+ )
else:
- pos_embedding = self.param('pos_embedding',
- cfg.posemb_init,
- pos_emb_shape)
+ pos_embedding = self.param(
+ 'pos_embedding', cfg.posemb_init, pos_emb_shape
+ )
pe = pos_embedding[:, :length, :]
# We use a cache position index for tracking decoding position.
if self.decode:
is_initialized = self.has_variable('cache', 'cache_index')
- cache_index = self.variable('cache',
- 'cache_index',
- lambda: jnp.array(0, dtype=jnp.uint32))
+ cache_index = self.variable(
+ 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.uint32)
+ )
if is_initialized:
i = cache_index.value
cache_index.value = i + 1
@@ -140,10 +144,10 @@ def __call__(self, inputs, inputs_positions=None):
class MlpBlock(nn.Module):
"""Transformer MLP / feed-forward block.
- Attributes:
- config: TransformerConfig dataclass containing hyperparameters.
- out_dim: optionally specify out dimension.
- """
+ Attributes:
+ config: TransformerConfig dataclass containing hyperparameters.
+ out_dim: optionally specify out dimension.
+ """
config: TransformerConfig
out_dim: Optional[int] = None
@@ -155,42 +159,41 @@ def __call__(self, inputs, dropout_rate=DROPOUT_RATE):
actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
x = nn.Dense(
+ cfg.mlp_dim,
+ dtype=cfg.dtype,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ )(inputs)
+ x = cfg.activation(x)
+ if cfg.glu:
+ y = nn.Dense(
cfg.mlp_dim,
dtype=cfg.dtype,
kernel_init=cfg.kernel_init,
bias_init=cfg.bias_init,
- )(
- inputs)
- x = cfg.activation(x)
- if cfg.glu:
- y = nn.Dense(
- cfg.mlp_dim,
- dtype=cfg.dtype,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- )(
- inputs)
+ )(inputs)
x = x * y
x = Dropout(rate=dropout_rate)(
- x, rate=dropout_rate, deterministic=cfg.deterministic)
+ x, rate=dropout_rate, deterministic=cfg.deterministic
+ )
output = nn.Dense(
- actual_out_dim,
- dtype=cfg.dtype,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- )(
- x)
+ actual_out_dim,
+ dtype=cfg.dtype,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ )(x)
output = Dropout(rate=dropout_rate)(
- output, rate=dropout_rate, deterministic=cfg.deterministic)
+ output, rate=dropout_rate, deterministic=cfg.deterministic
+ )
return output
class Encoder1DBlock(nn.Module):
"""Transformer encoder layer.
- Attributes:
- config: TransformerConfig dataclass containing hyperparameters.
- """
+ Attributes:
+ config: TransformerConfig dataclass containing hyperparameters.
+ """
config: TransformerConfig
@@ -198,13 +201,13 @@ class Encoder1DBlock(nn.Module):
def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE):
"""Applies Encoder1DBlock module.
- Args:
- inputs: input data.
- encoder_mask: encoder self-attention mask.
+ Args:
+ inputs: input data.
+ encoder_mask: encoder self-attention mask.
- Returns:
- output after transformer encoder block.
- """
+ Returns:
+ output after transformer encoder block.
+ """
cfg = self.config
pre_ln = cfg.pre_ln
@@ -213,19 +216,20 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE):
assert inputs.ndim == 3
x = nn.LayerNorm(dtype=cfg.dtype)(inputs) if pre_ln else inputs
x = nn.MultiHeadDotProductAttention(
- num_heads=cfg.num_heads,
- dtype=cfg.dtype,
- qkv_features=cfg.qkv_dim,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- use_bias=False,
- broadcast_dropout=False,
- dropout_rate=dropout_rate,
- deterministic=cfg.deterministic,
+ num_heads=cfg.num_heads,
+ dtype=cfg.dtype,
+ qkv_features=cfg.qkv_dim,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ use_bias=False,
+ broadcast_dropout=False,
+ dropout_rate=dropout_rate,
+ deterministic=cfg.deterministic,
)(cfg.attention_temp * x, x, mask=encoder_mask)
x = Dropout(rate=dropout_rate)(
- x, deterministic=cfg.deterministic, rate=dropout_rate)
+ x, deterministic=cfg.deterministic, rate=dropout_rate
+ )
x = x + inputs
if not pre_ln:
x = nn.LayerNorm(dtype=cfg.dtype)(x)
@@ -240,32 +244,32 @@ def __call__(self, inputs, encoder_mask=None, dropout_rate=DROPOUT_RATE):
class EncoderDecoder1DBlock(nn.Module):
"""Transformer encoder-decoder layer.
- Attributes:
- config: TransformerConfig dataclass containing hyperparameters.
- """
+ Attributes:
+ config: TransformerConfig dataclass containing hyperparameters.
+ """
config: TransformerConfig
@nn.compact
def __call__(
- self,
- targets,
- encoded,
- decoder_mask=None,
- encoder_decoder_mask=None,
- dropout_rate=DROPOUT_RATE,
+ self,
+ targets,
+ encoded,
+ decoder_mask=None,
+ encoder_decoder_mask=None,
+ dropout_rate=DROPOUT_RATE,
):
"""Applies EncoderDecoder1DBlock module.
- Args:
- targets: input data for decoder
- encoded: input data from encoder
- decoder_mask: decoder self-attention mask.
- encoder_decoder_mask: encoder-decoder attention mask.
+ Args:
+ targets: input data for decoder
+ encoded: input data from encoder
+ decoder_mask: decoder self-attention mask.
+ encoder_decoder_mask: encoder-decoder attention mask.
- Returns:
- output after transformer encoder-decoder block.
- """
+ Returns:
+ output after transformer encoder-decoder block.
+ """
cfg = self.config
pre_ln = cfg.pre_ln
@@ -275,20 +279,21 @@ def __call__(
x = nn.LayerNorm(dtype=cfg.dtype)(targets) if pre_ln else targets
x = nn.MultiHeadDotProductAttention(
- num_heads=cfg.num_heads,
- dtype=cfg.dtype,
- qkv_features=cfg.qkv_dim,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- use_bias=False,
- broadcast_dropout=False,
- dropout_rate=dropout_rate,
- deterministic=cfg.deterministic,
- decode=cfg.decode,
+ num_heads=cfg.num_heads,
+ dtype=cfg.dtype,
+ qkv_features=cfg.qkv_dim,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ use_bias=False,
+ broadcast_dropout=False,
+ dropout_rate=dropout_rate,
+ deterministic=cfg.deterministic,
+ decode=cfg.decode,
)(cfg.attention_temp * x, x, mask=decoder_mask)
x = Dropout(rate=dropout_rate)(
- x, deterministic=cfg.deterministic, rate=dropout_rate)
+ x, deterministic=cfg.deterministic, rate=dropout_rate
+ )
x = x + targets
if not pre_ln:
x = nn.LayerNorm(dtype=cfg.dtype)(x)
@@ -296,19 +301,20 @@ def __call__(
# Encoder-Decoder block.
y = nn.LayerNorm(dtype=cfg.dtype)(x) if pre_ln else x
y = nn.MultiHeadDotProductAttention(
- num_heads=cfg.num_heads,
- dtype=cfg.dtype,
- qkv_features=cfg.qkv_dim,
- kernel_init=cfg.kernel_init,
- bias_init=cfg.bias_init,
- use_bias=False,
- broadcast_dropout=False,
- dropout_rate=dropout_rate,
- deterministic=cfg.deterministic,
+ num_heads=cfg.num_heads,
+ dtype=cfg.dtype,
+ qkv_features=cfg.qkv_dim,
+ kernel_init=cfg.kernel_init,
+ bias_init=cfg.bias_init,
+ use_bias=False,
+ broadcast_dropout=False,
+ dropout_rate=dropout_rate,
+ deterministic=cfg.deterministic,
)(cfg.attention_temp * y, encoded, mask=encoder_decoder_mask)
y = Dropout(rate=dropout_rate)(
- y, deterministic=cfg.deterministic, rate=dropout_rate)
+ y, deterministic=cfg.deterministic, rate=dropout_rate
+ )
y = y + x
if not pre_ln:
y = nn.LayerNorm(dtype=cfg.dtype)(y)
@@ -323,30 +329,32 @@ def __call__(
class Encoder(nn.Module):
"""Transformer Model Encoder for sequence to sequence translation.
- Attributes:
- config: TransformerConfig dataclass containing hyperparameters.
- shared_embedding: a shared embedding layer to use.
- """
+ Attributes:
+ config: TransformerConfig dataclass containing hyperparameters.
+ shared_embedding: a shared embedding layer to use.
+ """
config: TransformerConfig
shared_embedding: Any = None
@nn.compact
- def __call__(self,
- inputs,
- inputs_positions=None,
- encoder_mask=None,
- dropout_rate=DROPOUT_RATE):
+ def __call__(
+ self,
+ inputs,
+ inputs_positions=None,
+ encoder_mask=None,
+ dropout_rate=DROPOUT_RATE,
+ ):
"""Applies Transformer model on the inputs.
- Args:
- inputs: input data
- inputs_positions: input subsequence positions for packed examples.
- encoder_mask: decoder self-attention mask.
+ Args:
+ inputs: input data
+ inputs_positions: input subsequence positions for packed examples.
+ encoder_mask: decoder self-attention mask.
- Returns:
- output of a transformer encoder.
- """
+ Returns:
+ output of a transformer encoder.
+ """
cfg = self.config
assert inputs.ndim == 2 # (batch, len)
@@ -354,30 +362,34 @@ def __call__(self,
# Input Embedding
if self.shared_embedding is None:
input_embed = nn.Embed(
- num_embeddings=cfg.vocab_size,
- features=cfg.emb_dim,
- embedding_init=nn.initializers.normal(stddev=1.0),
+ num_embeddings=cfg.vocab_size,
+ features=cfg.emb_dim,
+ embedding_init=nn.initializers.normal(stddev=1.0),
)
else:
input_embed = self.shared_embedding
- x = inputs.astype("int32")
+ x = inputs.astype('int32')
x = input_embed(x)
- x = AddPositionEmbs(
- config=cfg, decode=False, name="posembed_input")(
- x, inputs_positions=inputs_positions)
+ x = AddPositionEmbs(config=cfg, decode=False, name='posembed_input')(
+ x, inputs_positions=inputs_positions
+ )
x = Dropout(rate=dropout_rate)(
- x, deterministic=cfg.deterministic, rate=dropout_rate)
+ x, deterministic=cfg.deterministic, rate=dropout_rate
+ )
x = x.astype(cfg.dtype)
# Input Encoder
for lyr in range(cfg.num_layers):
- x = Encoder1DBlock(
- config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask, dropout_rate)
+ x = Encoder1DBlock(config=cfg, name=f'encoderblock_{lyr}')(
+ x, encoder_mask, dropout_rate
+ )
encoded = (
- nn.LayerNorm(dtype=cfg.dtype, name="encoder_layernorm")(x)
- if cfg.pre_ln else x)
+ nn.LayerNorm(dtype=cfg.dtype, name='encoder_layernorm')(x)
+ if cfg.pre_ln
+ else x
+ )
return encoded
@@ -385,36 +397,36 @@ def __call__(self,
class Decoder(nn.Module):
"""Transformer Model Decoder for sequence to sequence translation.
- Attributes:
- config: TransformerConfig dataclass containing hyperparameters.
- shared_embedding: a shared embedding layer to use.
- """
+ Attributes:
+ config: TransformerConfig dataclass containing hyperparameters.
+ shared_embedding: a shared embedding layer to use.
+ """
config: TransformerConfig
shared_embedding: Any = None
@nn.compact
def __call__(
- self,
- encoded,
- targets,
- targets_positions=None,
- decoder_mask=None,
- encoder_decoder_mask=None,
- dropout_rate=DROPOUT_RATE,
+ self,
+ encoded,
+ targets,
+ targets_positions=None,
+ decoder_mask=None,
+ encoder_decoder_mask=None,
+ dropout_rate=DROPOUT_RATE,
):
"""Applies Transformer model on the inputs.
- Args:
- encoded: encoded input data from encoder.
- targets: target inputs.
- targets_positions: input subsequence positions for packed examples.
- decoder_mask: decoder self-attention mask.
- encoder_decoder_mask: encoder-decoder attention mask.
+ Args:
+ encoded: encoded input data from encoder.
+ targets: target inputs.
+ targets_positions: input subsequence positions for packed examples.
+ decoder_mask: decoder self-attention mask.
+ encoder_decoder_mask: encoder-decoder attention mask.
- Returns:
- output of a transformer decoder.
- """
+ Returns:
+ output of a transformer decoder.
+ """
cfg = self.config
assert encoded.ndim == 3 # (batch, len, depth)
@@ -423,38 +435,40 @@ def __call__(
# Target Embedding
if self.shared_embedding is None:
output_embed = nn.Embed(
- num_embeddings=cfg.vocab_size,
- features=cfg.emb_dim,
- embedding_init=nn.initializers.normal(stddev=1.0),
+ num_embeddings=cfg.vocab_size,
+ features=cfg.emb_dim,
+ embedding_init=nn.initializers.normal(stddev=1.0),
)
else:
output_embed = self.shared_embedding
- y = targets.astype("int32")
+ y = targets.astype('int32')
if not cfg.decode:
y = shift_right(y)
y = output_embed(y)
- y = AddPositionEmbs(
- config=cfg, decode=cfg.decode, name="posembed_output")(
- y, inputs_positions=targets_positions)
+ y = AddPositionEmbs(config=cfg, decode=cfg.decode, name='posembed_output')(
+ y, inputs_positions=targets_positions
+ )
y = Dropout(rate=dropout_rate)(
- y, deterministic=cfg.deterministic, rate=dropout_rate)
+ y, deterministic=cfg.deterministic, rate=dropout_rate
+ )
y = y.astype(cfg.dtype)
# Target-Input Decoder
for lyr in range(cfg.num_layers):
- y = EncoderDecoder1DBlock(
- config=cfg, name=f"encoderdecoderblock_{lyr}")(
- y,
- encoded,
- decoder_mask=decoder_mask,
- encoder_decoder_mask=encoder_decoder_mask,
- dropout_rate=dropout_rate,
- )
+ y = EncoderDecoder1DBlock(config=cfg, name=f'encoderdecoderblock_{lyr}')(
+ y,
+ encoded,
+ decoder_mask=decoder_mask,
+ encoder_decoder_mask=encoder_decoder_mask,
+ dropout_rate=dropout_rate,
+ )
y = (
- nn.LayerNorm(dtype=cfg.dtype, name="encoderdecoder_layernorm")(y)
- if cfg.pre_ln else y)
+ nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_layernorm')(y)
+ if cfg.pre_ln
+ else y
+ )
# Use the transpose of embedding matrix for logit transform.
logits = output_embed.attend(y.astype(jnp.float32))
@@ -469,6 +483,7 @@ class Transformer(nn.Module):
Attributes:
config: TransformerConfig dataclass containing hyperparameters.
"""
+
config: TransformerConfig
def setup(self):
@@ -477,22 +492,26 @@ def setup(self):
if cfg.share_embeddings:
if cfg.vocab_size is not None:
assert cfg.vocab_size == cfg.vocab_size, (
- "can't share embedding with different vocab sizes.")
+ "can't share embedding with different vocab sizes."
+ )
self.shared_embedding = nn.Embed(
- num_embeddings=cfg.vocab_size,
- features=cfg.emb_dim,
- embedding_init=nn.initializers.normal(stddev=1.0))
+ num_embeddings=cfg.vocab_size,
+ features=cfg.emb_dim,
+ embedding_init=nn.initializers.normal(stddev=1.0),
+ )
else:
self.shared_embedding = None
self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding)
self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding)
- def encode(self,
- inputs,
- inputs_positions=None,
- inputs_segmentation=None,
- dropout_rate=DROPOUT_RATE):
+ def encode(
+ self,
+ inputs,
+ inputs_positions=None,
+ inputs_segmentation=None,
+ dropout_rate=DROPOUT_RATE,
+ ):
"""Applies Transformer encoder-branch on the inputs.
Args:
@@ -506,31 +525,33 @@ def encode(self,
cfg = self.config
# Make padding attention mask.
encoder_mask = nn.make_attention_mask(
- inputs > 0, inputs > 0, dtype=cfg.dtype)
+ inputs > 0, inputs > 0, dtype=cfg.dtype
+ )
# Add segmentation block-diagonal attention mask if using segmented data.
if inputs_segmentation is not None:
encoder_mask = nn.combine_masks(
- encoder_mask,
- nn.make_attention_mask(
- inputs_segmentation,
- inputs_segmentation,
- jnp.equal,
- dtype=cfg.dtype))
+ encoder_mask,
+ nn.make_attention_mask(
+ inputs_segmentation, inputs_segmentation, jnp.equal, dtype=cfg.dtype
+ ),
+ )
return self.encoder(
- inputs,
- inputs_positions=inputs_positions,
- encoder_mask=encoder_mask,
- dropout_rate=dropout_rate)
+ inputs,
+ inputs_positions=inputs_positions,
+ encoder_mask=encoder_mask,
+ dropout_rate=dropout_rate,
+ )
def decode(
- self,
- encoded,
- inputs, # only needed for masks
- targets,
- targets_positions=None,
- inputs_segmentation=None,
- targets_segmentation=None,
- dropout_rate=DROPOUT_RATE):
+ self,
+ encoded,
+ inputs, # only needed for masks
+ targets,
+ targets_positions=None,
+ inputs_segmentation=None,
+ targets_segmentation=None,
+ dropout_rate=DROPOUT_RATE,
+ ):
"""Applies Transformer decoder-branch on encoded-input and target.
Args:
@@ -550,47 +571,51 @@ def decode(
if cfg.decode:
decoder_mask = None
encoder_decoder_mask = nn.make_attention_mask(
- jnp.ones_like(targets) > 0, inputs > 0, dtype=cfg.dtype)
+ jnp.ones_like(targets) > 0, inputs > 0, dtype=cfg.dtype
+ )
else:
decoder_mask = nn.combine_masks(
- nn.make_attention_mask(targets > 0, targets > 0, dtype=cfg.dtype),
- nn.make_causal_mask(targets, dtype=cfg.dtype))
+ nn.make_attention_mask(targets > 0, targets > 0, dtype=cfg.dtype),
+ nn.make_causal_mask(targets, dtype=cfg.dtype),
+ )
encoder_decoder_mask = nn.make_attention_mask(
- targets > 0, inputs > 0, dtype=cfg.dtype)
+ targets > 0, inputs > 0, dtype=cfg.dtype
+ )
# Add segmentation block-diagonal attention masks if using segmented data.
if inputs_segmentation is not None:
decoder_mask = nn.combine_masks(
- decoder_mask,
- nn.make_attention_mask(
- targets_segmentation,
- targets_segmentation,
- jnp.equal,
- dtype=cfg.dtype))
+ decoder_mask,
+ nn.make_attention_mask(
+ targets_segmentation, targets_segmentation, jnp.equal, dtype=cfg.dtype
+ ),
+ )
encoder_decoder_mask = nn.combine_masks(
- encoder_decoder_mask,
- nn.make_attention_mask(
- targets_segmentation,
- inputs_segmentation,
- jnp.equal,
- dtype=cfg.dtype))
+ encoder_decoder_mask,
+ nn.make_attention_mask(
+ targets_segmentation, inputs_segmentation, jnp.equal, dtype=cfg.dtype
+ ),
+ )
logits = self.decoder(
- encoded,
- targets,
- targets_positions=targets_positions,
- decoder_mask=decoder_mask,
- encoder_decoder_mask=encoder_decoder_mask,
- dropout_rate=dropout_rate)
+ encoded,
+ targets,
+ targets_positions=targets_positions,
+ decoder_mask=decoder_mask,
+ encoder_decoder_mask=encoder_decoder_mask,
+ dropout_rate=dropout_rate,
+ )
return logits.astype(self.config.dtype)
- def __call__(self,
- inputs,
- targets,
- inputs_positions=None,
- targets_positions=None,
- inputs_segmentation=None,
- targets_segmentation=None,
- dropout_rate=DROPOUT_RATE):
+ def __call__(
+ self,
+ inputs,
+ targets,
+ inputs_positions=None,
+ targets_positions=None,
+ inputs_segmentation=None,
+ targets_segmentation=None,
+ dropout_rate=DROPOUT_RATE,
+ ):
"""Applies Transformer model on the inputs.
Args:
@@ -605,16 +630,18 @@ def __call__(self,
logits array from full transformer.
"""
encoded = self.encode(
- inputs,
- inputs_positions=inputs_positions,
- inputs_segmentation=inputs_segmentation,
- dropout_rate=dropout_rate)
+ inputs,
+ inputs_positions=inputs_positions,
+ inputs_segmentation=inputs_segmentation,
+ dropout_rate=dropout_rate,
+ )
return self.decode(
- encoded,
- inputs, # only used for masks
- targets,
- targets_positions=targets_positions,
- inputs_segmentation=inputs_segmentation,
- targets_segmentation=targets_segmentation,
- dropout_rate=dropout_rate)
+ encoded,
+ inputs, # only used for masks
+ targets,
+ targets_positions=targets_positions,
+ inputs_segmentation=inputs_segmentation,
+ targets_segmentation=targets_segmentation,
+ dropout_rate=dropout_rate,
+ )
diff --git a/algoperf/workloads/wmt/wmt_jax/workload.py b/algoperf/workloads/wmt/wmt_jax/workload.py
index d402f9d95..51d8a85a7 100644
--- a/algoperf/workloads/wmt/wmt_jax/workload.py
+++ b/algoperf/workloads/wmt/wmt_jax/workload.py
@@ -1,23 +1,21 @@
"""WMT workload implemented in Jax."""
-from dataclasses import replace
import functools
+from dataclasses import replace
from typing import Any, Dict, Iterator, Optional, Tuple
-from absl import logging
-from flax import jax_utils
-from flax import linen as nn
-from flax.training import common_utils
import jax
import jax.numpy as jnp
import numpy as np
import optax
+from absl import logging
+from flax import jax_utils
+from flax import linen as nn
+from flax.training import common_utils
-from algoperf import param_utils
-from algoperf import spec
+from algoperf import param_utils, spec
from algoperf.workloads.wmt import bleu
-from algoperf.workloads.wmt.wmt_jax import decode
-from algoperf.workloads.wmt.wmt_jax import models
+from algoperf.workloads.wmt.wmt_jax import decode, models
from algoperf.workloads.wmt.workload import BaseWmtWorkload
@@ -31,11 +29,12 @@ class WmtWorkload(BaseWmtWorkload):
"""WMT Jax workload."""
def compute_weighted_cross_entropy(
- self,
- logits: spec.Tensor,
- targets: spec.Tensor,
- weights: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.1) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ logits: spec.Tensor,
+ targets: spec.Tensor,
+ weights: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.1,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Compute weighted cross entropy and entropy for log probs and targets.
Args:
@@ -50,76 +49,86 @@ def compute_weighted_cross_entropy(
valid examples in batch, 'per_example': 1-d array of per-example losses}
"""
if logits.ndim != targets.ndim + 1:
- raise ValueError(f'Incorrect shapes. Got shape {logits.shape} logits and '
- f'{targets.shape} targets.')
+ raise ValueError(
+ f'Incorrect shapes. Got shape {logits.shape} logits and '
+ f'{targets.shape} targets.'
+ )
smoothed_targets = optax.smooth_labels(
- common_utils.onehot(targets, self._vocab_size), label_smoothing)
+ common_utils.onehot(targets, self._vocab_size), label_smoothing
+ )
per_example_losses = -jnp.sum(
- smoothed_targets * nn.log_softmax(logits), axis=-1)
+ smoothed_targets * nn.log_softmax(logits), axis=-1
+ )
if weights is None:
weights = jnp.ones_like(targets)
- per_example_losses = jnp.where(weights, per_example_losses, 0.)
+ per_example_losses = jnp.where(weights, per_example_losses, 0.0)
summed_loss = per_example_losses.sum()
n_valid_examples = weights.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': n_valid_examples,
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': n_valid_examples,
+ 'per_example': per_example_losses,
}
@functools.partial(
- jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,))
+ jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)
+ )
def eval_step_pmapped(
- self, params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]:
+ self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]
+ ) -> Dict[str, spec.Tensor]:
"""Calculate evaluation metrics on a batch."""
inputs = batch['inputs']
targets = batch['targets']
weights = batch['weights']
logits = self._eval_model.apply({'params': params}, inputs, targets)
- summed_loss = self.compute_weighted_cross_entropy(logits,
- targets,
- weights,
- 0.0)['summed']
+ summed_loss = self.compute_weighted_cross_entropy(
+ logits, targets, weights, 0.0
+ )['summed']
acc_sum, weight_sum = self.compute_weighted_accuracy(
- logits, targets, weights)
+ logits, targets, weights
+ )
return {
- 'loss': summed_loss,
- 'accuracy': acc_sum,
- 'denominator': weight_sum,
+ 'loss': summed_loss,
+ 'accuracy': acc_sum,
+ 'denominator': weight_sum,
}
- def eval_step(self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]:
+ def eval_step(
+ self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]
+ ) -> Dict[str, spec.Tensor]:
replicated_eval_metrics = self.eval_step_pmapped(params, batch)
return jax.tree.map(lambda x: jnp.sum(x, axis=0), replicated_eval_metrics)
@functools.partial(
- jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,))
- def initialize_cache(self,
- inputs: spec.Tensor,
- max_decode_len: int = 256) -> Dict[str, spec.Tensor]:
+ jax.pmap, axis_name='batch', static_broadcasted_argnums=(0,)
+ )
+ def initialize_cache(
+ self, inputs: spec.Tensor, max_decode_len: int = 256
+ ) -> Dict[str, spec.Tensor]:
"""Initialize a cache for a given input shape and max decode length."""
config = models.TransformerConfig(deterministic=True, decode=True)
target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:]
initial_variables = models.Transformer(config).init(
- jax.random.PRNGKey(0),
- jnp.ones(inputs.shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32))
+ jax.random.PRNGKey(0),
+ jnp.ones(inputs.shape, jnp.float32),
+ jnp.ones(target_shape, jnp.float32),
+ )
return initial_variables['cache']
# eos_id, max_decode_len are constant.
@functools.partial(
- jax.pmap, axis_name='batch', static_broadcasted_argnums=(0, 4, 5))
- def predict_step(self,
- inputs: spec.Tensor,
- params: spec.ParameterContainer,
- cache: Dict[str, spec.Tensor],
- eos_id: int,
- max_decode_len: int,
- beam_size: int = 4) -> spec.Tensor:
+ jax.pmap, axis_name='batch', static_broadcasted_argnums=(0, 4, 5)
+ )
+ def predict_step(
+ self,
+ inputs: spec.Tensor,
+ params: spec.ParameterContainer,
+ cache: Dict[str, spec.Tensor],
+ eos_id: int,
+ max_decode_len: int,
+ beam_size: int = 4,
+ ) -> spec.Tensor:
"""Predict translation with fast decoding beam search on a batch."""
config = replace(self._eval_model.config, decode=True)
# Prepare transformer fast-decoder call for beam search: for beam search, we
@@ -129,27 +138,29 @@ def predict_step(self,
# i.e. if we denote each batch element subtensor as el[n]:
# [el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2]
encoded_inputs = decode.flat_batch_beam_expand(
- models.Transformer(config).apply({'params': params},
- inputs,
- method=models.Transformer.encode),
- beam_size)
+ models.Transformer(config).apply(
+ {'params': params}, inputs, method=models.Transformer.encode
+ ),
+ beam_size,
+ )
raw_inputs = decode.flat_batch_beam_expand(inputs, beam_size)
def tokens_ids_to_logits(
- flat_ids: spec.Tensor, flat_cache: Dict[str, spec.Tensor]
+ flat_ids: spec.Tensor, flat_cache: Dict[str, spec.Tensor]
) -> Tuple[spec.Tensor, Dict[str, spec.Tensor]]:
"""Token slice to logits from decoder model."""
# --> [batch * beam, 1, vocab]
flat_logits, new_vars = models.Transformer(config).apply(
- {
- 'params': params,
- 'cache': flat_cache,
- },
- encoded_inputs,
- raw_inputs, # only needed for input padding mask
- flat_ids,
- mutable=['cache'],
- method=models.Transformer.decode)
+ {
+ 'params': params,
+ 'cache': flat_cache,
+ },
+ encoded_inputs,
+ raw_inputs, # only needed for input padding mask
+ flat_ids,
+ mutable=['cache'],
+ method=models.Transformer.decode,
+ )
new_flat_cache = new_vars['cache']
# Remove singleton sequence-length dimension:
# [batch * beam, 1, vocab] --> [batch * beam, vocab]
@@ -159,35 +170,36 @@ def tokens_ids_to_logits(
# Using the above-defined single-step decoder function, run a
# beam search over possible sequences given input encoding.
beam_seqs, _ = decode.beam_search(
- inputs,
- cache,
- tokens_ids_to_logits,
- beam_size=beam_size,
- alpha=0.6,
- eos_id=eos_id,
- max_decode_len=max_decode_len)
+ inputs,
+ cache,
+ tokens_ids_to_logits,
+ beam_size=beam_size,
+ alpha=0.6,
+ eos_id=eos_id,
+ max_decode_len=max_decode_len,
+ )
# Beam search returns [n_batch, n_beam, n_length + 1] with beam dimension
# sorted in increasing order of log-probability.
# Return the highest scoring beam sequence, drop first dummy 0 token.
return beam_seqs[:, -1, 1:]
- def translate_and_calculate_bleu(self,
- params: spec.ParameterContainer,
- ds_iter: Iterator,
- num_batches: int,
- max_predict_length: int) -> spec.Tensor:
+ def translate_and_calculate_bleu(
+ self,
+ params: spec.ParameterContainer,
+ ds_iter: Iterator,
+ num_batches: int,
+ max_predict_length: int,
+ ) -> spec.Tensor:
"""Translates the `predict_ds` and calculates the BLEU score."""
logging.info('Translating evaluation dataset.')
references, predictions = [], []
for _ in range(num_batches):
pred_batch = next(ds_iter)
cache = self.initialize_cache(pred_batch['inputs'])
- predicted = self.predict_step(pred_batch['inputs'],
- params,
- cache,
- decode.EOS_ID,
- max_predict_length)
+ predicted = self.predict_step(
+ pred_batch['inputs'], params, cache, decode.EOS_ID, max_predict_length
+ )
predicted = _to_host(predicted)
targets = _to_host(pred_batch['targets'])
# Find actual batch size, ignoring the potential padding.
@@ -219,18 +231,20 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
raise ValueError(f'Unknown activation function {self.activation}.')
model_config = models.TransformerConfig(
- pre_ln=self.pre_ln,
- attention_temp=self.attention_temp,
- activation=activation,
- glu=self.glu)
+ pre_ln=self.pre_ln,
+ attention_temp=self.attention_temp,
+ activation=activation,
+ glu=self.glu,
+ )
self._train_model = models.Transformer(model_config)
eval_config = replace(model_config, deterministic=True)
self._eval_model = models.Transformer(eval_config)
params_rng, _ = jax.random.split(rng)
- initial_variables = jax.jit(
- self._eval_model.init)({'params': params_rng},
- jnp.ones(input_shape, jnp.float32),
- jnp.ones(target_shape, jnp.float32))
+ initial_variables = jax.jit(self._eval_model.init)(
+ {'params': params_rng},
+ jnp.ones(input_shape, jnp.float32),
+ jnp.ones(target_shape, jnp.float32),
+ )
initial_params = initial_variables['params']
self._param_shapes = param_utils.jax_param_shapes(initial_params)
@@ -241,14 +255,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'shared_embedding'
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- dropout_rate: [float] = models.DROPOUT_RATE
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+    dropout_rate: float = models.DROPOUT_RATE,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del update_batch_norm
@@ -256,33 +270,39 @@ def model_fn(
inputs = augmented_and_preprocessed_input_batch.get('inputs', None)
targets = augmented_and_preprocessed_input_batch.get('targets', None)
inputs_positions = augmented_and_preprocessed_input_batch.get(
- 'inputs_position', None)
+ 'inputs_position', None
+ )
targets_positions = augmented_and_preprocessed_input_batch.get(
- 'targets_position', None)
+ 'targets_position', None
+ )
inputs_segmentations = augmented_and_preprocessed_input_batch.get(
- 'inputs_segmentation', None)
+ 'inputs_segmentation', None
+ )
targets_segmentations = augmented_and_preprocessed_input_batch.get(
- 'targets_segmentation', None)
+ 'targets_segmentation', None
+ )
if mode == spec.ForwardPassMode.TRAIN:
model = self._train_model
else:
model = self._eval_model
- logits_batch = model.apply({'params': params},
- inputs,
- targets,
- inputs_positions=inputs_positions,
- targets_positions=targets_positions,
- inputs_segmentation=inputs_segmentations,
- targets_segmentation=targets_segmentations,
- rngs={'dropout': rng},
- dropout_rate=dropout_rate)
+ logits_batch = model.apply(
+ {'params': params},
+ inputs,
+ targets,
+ inputs_positions=inputs_positions,
+ targets_positions=targets_positions,
+ inputs_segmentation=inputs_segmentations,
+ targets_segmentation=targets_segmentations,
+ rngs={'dropout': rng},
+ dropout_rate=dropout_rate,
+ )
return logits_batch, None
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
del num_examples
eval_denominator = total_metrics.pop('denominator')
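
As a side note on the loss shown above: compute_weighted_cross_entropy combines label smoothing with a padding mask before summing. A minimal, self-contained sketch of that computation on toy data follows; the vocabulary size, shapes, and the helper name weighted_xent are illustrative only and are not taken from the workload.

import jax
import jax.numpy as jnp
import optax

def weighted_xent(logits, targets, weights=None, label_smoothing=0.1):
  vocab_size = logits.shape[-1]
  # Smooth the one-hot targets, then take the usual cross entropy per token.
  smoothed = optax.smooth_labels(jax.nn.one_hot(targets, vocab_size), label_smoothing)
  per_example = -jnp.sum(smoothed * jax.nn.log_softmax(logits), axis=-1)
  if weights is None:
    weights = jnp.ones_like(targets)
  # Zero out padding positions so they contribute nothing to the sum.
  per_example = jnp.where(weights, per_example, 0.0)
  return per_example.sum(), weights.sum()

logits = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 5))  # [batch, len, vocab]
targets = jnp.array([[1, 2, 0], [3, 4, 0]])
weights = (targets > 0).astype(jnp.float32)                   # mask the padding id 0
summed, n_valid = weighted_xent(logits, targets, weights)
print(summed / n_valid)                                       # mean loss over valid tokens
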
diff --git a/algoperf/workloads/wmt/wmt_pytorch/decode.py b/algoperf/workloads/wmt/wmt_pytorch/decode.py
index 26ff36650..7974412d7 100644
--- a/algoperf/workloads/wmt/wmt_pytorch/decode.py
+++ b/algoperf/workloads/wmt/wmt_pytorch/decode.py
@@ -21,8 +21,9 @@
NEG_INF = torch.tensor(-1.0e7, device=DEVICE)
-def brevity_penalty(alpha: float, length: Union[int,
- torch.Tensor]) -> torch.Tensor:
+def brevity_penalty(
+ alpha: float, length: Union[int, torch.Tensor]
+) -> torch.Tensor:
"""Brevity penalty function for beam search penalizing short sequences.
Args:
@@ -57,8 +58,9 @@ def flatten_beam_dim(x: torch.Tensor) -> torch.Tensor:
return x.view(-1, *x.shape[2:])
-def unflatten_beam_dim(x: torch.Tensor, batch_size: int,
- beam_size: int) -> torch.Tensor:
+def unflatten_beam_dim(
+ x: torch.Tensor, batch_size: int, beam_size: int
+) -> torch.Tensor:
"""Unflattens the first, flat batch*beam dimension of a non-scalar tensor."""
if x.dim() < 2: # ignore scalars (e.g. cache index)
return x
@@ -71,10 +73,12 @@ def flat_batch_beam_expand(x: torch.Tensor, beam_size: int) -> torch.Tensor:
return flatten_beam_dim(add_beam_dim(x, beam_size))
-def gather_beams(nested: Dict[str, Any],
- beam_indices: torch.Tensor,
- batch_size: int,
- new_beam_size: int) -> Dict[str, Any]:
+def gather_beams(
+ nested: Dict[str, Any],
+ beam_indices: torch.Tensor,
+ batch_size: int,
+ new_beam_size: int,
+) -> Dict[str, Any]:
"""Gathers the beam slices indexed by beam_indices into new beam tensor.
Args:
@@ -88,10 +92,13 @@ def gather_beams(nested: Dict[str, Any],
[batch_size, old_beam_size, ...] --> [batch_size, new_beam_size, ...]
"""
batch_indices = torch.reshape(
- torch.div(
- torch.arange(batch_size * new_beam_size, device=DEVICE),
- new_beam_size,
- rounding_mode='floor'), (batch_size, new_beam_size))
+ torch.div(
+ torch.arange(batch_size * new_beam_size, device=DEVICE),
+ new_beam_size,
+ rounding_mode='floor',
+ ),
+ (batch_size, new_beam_size),
+ )
def gather_fn(x):
if x.dim() < 2: # ignore scalars (e.g. cache index)
@@ -101,10 +108,12 @@ def gather_fn(x):
return jax.tree.map(gather_fn, nested)
-def gather_topk_beams(nested: Dict[str, Any],
- score_or_log_prob: torch.Tensor,
- batch_size: int,
- new_beam_size: int) -> Dict[str, Any]:
+def gather_topk_beams(
+ nested: Dict[str, Any],
+ score_or_log_prob: torch.Tensor,
+ batch_size: int,
+ new_beam_size: int,
+) -> Dict[str, Any]:
"""Gathers the top-k beam slices given by score_or_log_prob array.
Args:
@@ -129,6 +138,7 @@ def gather_topk_beams(nested: Dict[str, Any],
@dataclass
class BeamState:
"""Holds beam search state data."""
+
# The position of the decoding loop in the length dimension.
cur_index: torch.Tensor # scalar int32: current decoded length index.
# The active sequence log probabilities and finished sequence scores.
@@ -143,49 +153,52 @@ class BeamState:
cache: Dict[str, Any] # Any dict (of dicts), with torch.Tensors as leafs.
-def beam_init(batch_size: int,
- beam_size: int,
- max_decode_len: int,
- cache: Dict[str, Any]) -> BeamState:
+def beam_init(
+ batch_size: int, beam_size: int, max_decode_len: int, cache: Dict[str, Any]
+) -> BeamState:
"""Initializes the beam search state data structure."""
cur_index0 = torch.tensor(0, device=DEVICE)
live_logprobs0 = torch.tile(
- torch.tensor([0.0] + [NEG_INF] * (beam_size - 1), device=DEVICE),
- [batch_size, 1])
+ torch.tensor([0.0] + [NEG_INF] * (beam_size - 1), device=DEVICE),
+ [batch_size, 1],
+ )
finished_scores0 = (
- torch.ones((batch_size, beam_size), device=DEVICE) * NEG_INF)
- live_seqs0 = torch.zeros((batch_size, beam_size, max_decode_len),
- dtype=torch.int32,
- device=DEVICE)
- finished_seqs0 = torch.zeros((batch_size, beam_size, max_decode_len),
- dtype=torch.int32,
- device=DEVICE)
- finished_flags0 = torch.zeros((batch_size, beam_size),
- dtype=torch.bool,
- device=DEVICE)
+ torch.ones((batch_size, beam_size), device=DEVICE) * NEG_INF
+ )
+ live_seqs0 = torch.zeros(
+ (batch_size, beam_size, max_decode_len), dtype=torch.int32, device=DEVICE
+ )
+ finished_seqs0 = torch.zeros(
+ (batch_size, beam_size, max_decode_len), dtype=torch.int32, device=DEVICE
+ )
+ finished_flags0 = torch.zeros(
+ (batch_size, beam_size), dtype=torch.bool, device=DEVICE
+ )
# add beam dimension to attention cache pytree elements
beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size), cache)
return BeamState(
- cur_index=cur_index0,
- live_logprobs=live_logprobs0,
- finished_scores=finished_scores0,
- live_seqs=live_seqs0,
- finished_seqs=finished_seqs0,
- finished_flags=finished_flags0,
- cache=beam_cache0)
+ cur_index=cur_index0,
+ live_logprobs=live_logprobs0,
+ finished_scores=finished_scores0,
+ live_seqs=live_seqs0,
+ finished_seqs=finished_seqs0,
+ finished_flags=finished_flags0,
+ cache=beam_cache0,
+ )
# Beam search routine:
def beam_search(
- inputs: torch.Tensor,
- cache: Optional[Dict[str, Any]],
- tokens_to_logits: Callable,
- beam_size: int = 4,
- alpha: float = 0.6,
- eos_id: int = EOS_ID,
- max_decode_len: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ inputs: torch.Tensor,
+ cache: Optional[Dict[str, Any]],
+ tokens_to_logits: Callable,
+ beam_size: int = 4,
+ alpha: float = 0.6,
+ eos_id: int = EOS_ID,
+ max_decode_len: Optional[int] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
"""Beam search for transformer machine translation.
Args:
@@ -211,10 +224,9 @@ def beam_search(
end_marker = torch.tensor(eos_id, device=DEVICE)
# initialize beam search state
- beam_search_init_state = beam_init(batch_size,
- beam_size,
- max_decode_len,
- cache)
+ beam_search_init_state = beam_init(
+ batch_size, beam_size, max_decode_len, cache
+ )
def beam_search_loop_cond_fn(state: BeamState) -> bool:
"""Beam search loop termination condition."""
@@ -227,11 +239,12 @@ def beam_search_loop_cond_fn(state: BeamState) -> bool:
best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty
# Get the worst scores from finished sequences.
worst_finished_scores, _ = torch.min(
- state.finished_scores, dim=1, keepdim=True)
+ state.finished_scores, dim=1, keepdim=True
+ )
# Mask out scores from slots without any actual finished sequences.
- worst_finished_scores = torch.where(state.finished_flags,
- worst_finished_scores,
- NEG_INF)
+ worst_finished_scores = torch.where(
+ state.finished_flags, worst_finished_scores, NEG_INF
+ )
# If no best possible live score is better than current worst finished
# scores, the search cannot improve the finished set further.
search_terminated = torch.all(worst_finished_scores > best_live_scores)
@@ -248,7 +261,8 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
# --> [batch * beam, 1]
cur_index = state.cur_index
flat_ids = flatten_beam_dim(
- state.live_seqs[:batch_size, :beam_size, cur_index:cur_index + 1])
+ state.live_seqs[:batch_size, :beam_size, cur_index : cur_index + 1]
+ )
# Flatten beam dimension into batch to be compatible with model.
# {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
flat_cache = jax.tree.map(flatten_beam_dim, state.cache)
@@ -263,7 +277,8 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
# Unflatten beam dimension in attention cache arrays
# {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
new_cache = jax.tree.map(
- lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache)
+ lambda x: unflatten_beam_dim(x, batch_size, beam_size), new_flat_cache
+ )
# Gather log probabilities from logits
candidate_log_probs = F.log_softmax(logits, dim=-1)
@@ -287,13 +302,13 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
topk_log_probs, topk_indices = torch.topk(flat_log_probs, k=beams_to_keep)
# Recover the beam index by floor division.
topk_beam_indices = torch.div(
- topk_indices, vocab_size, rounding_mode='floor')
+ topk_indices, vocab_size, rounding_mode='floor'
+ )
# Gather 2*k top beams.
# --> [batch, 2*beams, length]
- topk_seq = gather_beams(state.live_seqs,
- topk_beam_indices,
- batch_size,
- beams_to_keep)
+ topk_seq = gather_beams(
+ state.live_seqs, topk_beam_indices, batch_size, beams_to_keep
+ )
# Append the most probable 2*K token IDs to the top 2*K sequences
# Recover token id by modulo division and expand Id array for broadcasting.
@@ -301,11 +316,11 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
topk_ids = torch.unsqueeze(topk_indices % vocab_size, dim=2)
# Update sequences for the 2*K top-k new sequences.
# --> [batch, 2*beams, length]
- topk_seq[:, :, cur_index + 1:] = topk_ids
+ topk_seq[:, :, cur_index + 1 :] = topk_ids
# Update LIVE (in-progress) sequences:
# Did any of these sequences reach an end marker?
# --> [batch, 2*beams]
- newly_finished = (topk_seq[:, :, cur_index + 1] == end_marker)
+ newly_finished = topk_seq[:, :, cur_index + 1] == end_marker
# To prevent these newly finished sequences from being added to the LIVE
# set of active beam search sequences, set their log probs to a very large
# negative value.
@@ -316,22 +331,20 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
new_topk_indices = torch.flip(new_topk_indices, (1,))
# Gather the top k beams (from top 2*k beams).
# --> [batch, beams, length], [batch, beams]
- top_alive_seq, top_alive_log_probs = gather_beams([topk_seq, new_log_probs],
- new_topk_indices,
- batch_size, beam_size)
+ top_alive_seq, top_alive_log_probs = gather_beams(
+ [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size
+ )
# Determine the top k beam indices from the original set of all beams.
# --> [batch, beams]
- top_alive_indices = gather_beams(topk_beam_indices,
- new_topk_indices,
- batch_size,
- beam_size)
+ top_alive_indices = gather_beams(
+ topk_beam_indices, new_topk_indices, batch_size, beam_size
+ )
# With these, gather the top k beam-associated caches.
# --> {[batch, beams, ...], ...}
- top_alive_cache = gather_beams(new_cache,
- top_alive_indices,
- batch_size,
- beam_size)
+ top_alive_cache = gather_beams(
+ new_cache, top_alive_indices, batch_size, beam_size
+ )
# Update FINISHED (reached end of sentence) sequences:
# Calculate new seq scores from log probabilities.
@@ -344,24 +357,33 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
# new finished sequence scores to existing finished scores and select the
# best from the new set of beams.
finished_seqs = torch.cat( # --> [batch, 3*beams, length]
- [state.finished_seqs, topk_seq], dim=1)
+ [state.finished_seqs, topk_seq], dim=1
+ )
finished_scores = torch.cat( # --> [batch, 3*beams]
- [state.finished_scores, new_scores], dim=1)
+ [state.finished_scores, new_scores], dim=1
+ )
finished_flags = torch.cat( # --> [batch, 3*beams]
- [state.finished_flags, newly_finished], dim=1)
+ [state.finished_flags, newly_finished], dim=1
+ )
# --> [batch, beams, length], [batch, beams], [batch, beams]
top_finished_seq, top_finished_scores, top_finished_flags = (
- gather_topk_beams([finished_seqs, finished_scores, finished_flags],
- finished_scores, batch_size, beam_size))
+ gather_topk_beams(
+ [finished_seqs, finished_scores, finished_flags],
+ finished_scores,
+ batch_size,
+ beam_size,
+ )
+ )
return BeamState(
- cur_index=cur_index + 1,
- live_logprobs=top_alive_log_probs,
- finished_scores=top_finished_scores,
- live_seqs=top_alive_seq,
- finished_seqs=top_finished_seq,
- finished_flags=top_finished_flags,
- cache=top_alive_cache)
+ cur_index=cur_index + 1,
+ live_logprobs=top_alive_log_probs,
+ finished_scores=top_finished_scores,
+ live_seqs=top_alive_seq,
+ finished_seqs=top_finished_seq,
+ finished_flags=top_finished_flags,
+ cache=top_alive_cache,
+ )
state = beam_search_init_state
while beam_search_loop_cond_fn(state):
@@ -373,12 +395,16 @@ def beam_search_loop_body_fn(state: BeamState) -> BeamState:
# --> [batch]
none_finished = torch.any(final_state.finished_flags, dim=1)
# --> [batch, beams, length]
- finished_seqs = torch.where(none_finished[:, None, None],
- final_state.finished_seqs,
- final_state.live_seqs)
+ finished_seqs = torch.where(
+ none_finished[:, None, None],
+ final_state.finished_seqs,
+ final_state.live_seqs,
+ )
# --> [batch, beams]
- finished_scores = torch.where(none_finished[:, None],
- final_state.finished_scores,
- final_state.live_logprobs)
+ finished_scores = torch.where(
+ none_finished[:, None],
+ final_state.finished_scores,
+ final_state.live_logprobs,
+ )
return finished_seqs, finished_scores
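
The beam-search helpers above repeatedly move between a [batch, beam, ...] view and a flat [batch * beam, ...] view so the decoder can be called once per step for all beams. A minimal sketch of that expansion and its inverse, mirroring the "[el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2]" comment, is shown below on toy tensors (shapes and values are illustrative).

import torch

batch, beam, d = 3, 2, 4
# One distinct row per batch element so the interleaving is easy to see.
x = torch.arange(batch, dtype=torch.float32).view(batch, 1).expand(batch, d)

# Repeat every batch element beam times along the flat batch*beam axis.
flat = torch.repeat_interleave(x, repeats=beam, dim=0)   # [batch * beam, d]
print(flat[:, 0])            # tensor([0., 0., 1., 1., 2., 2.])

# Recover the explicit beam dimension when a per-beam view is needed again.
unflat = flat.view(batch, beam, *flat.shape[1:])         # [batch, beam, d]
print(tuple(unflat.shape))   # (3, 2, 4)
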
diff --git a/algoperf/workloads/wmt/wmt_pytorch/models.py b/algoperf/workloads/wmt/wmt_pytorch/models.py
index 0de719c4b..430cc945b 100644
--- a/algoperf/workloads/wmt/wmt_pytorch/models.py
+++ b/algoperf/workloads/wmt/wmt_pytorch/models.py
@@ -3,11 +3,9 @@
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
-from torch import nn
-from torch import Tensor
import torch.nn.functional as F
-from torch.nn.init import normal_
-from torch.nn.init import xavier_uniform_
+from torch import Tensor, nn
+from torch.nn.init import normal_, xavier_uniform_
DROPOUT_RATE = 0.1
@@ -23,7 +21,8 @@ def make_causal_mask(x: Tensor, device: str = 'cuda:0') -> Tensor:
A `[batch..., len, len]` shaped causal attention mask.
"""
idxs = torch.broadcast_to(
- torch.arange(x.shape[-1], dtype=torch.int32, device=device), x.shape)
+ torch.arange(x.shape[-1], dtype=torch.int32, device=device), x.shape
+ )
return torch.greater_equal(idxs.unsqueeze(-1), idxs.unsqueeze(-2))
@@ -33,55 +32,60 @@ def make_src_mask(src, inputs_segmentation, nhead):
# Add segmentation block-diagonal attention mask if using segmented data.
if inputs_segmentation is not None:
src_mask = torch.logical_and(
- src_mask,
- torch.eq(
- inputs_segmentation.unsqueeze(-1),
- inputs_segmentation.unsqueeze(-2)))
+ src_mask,
+ torch.eq(
+ inputs_segmentation.unsqueeze(-1), inputs_segmentation.unsqueeze(-2)
+ ),
+ )
# Flip values and ensure numerical stability.
src_mask = torch.repeat_interleave(
- torch.logical_not(src_mask), repeats=nhead, dim=0)
+ torch.logical_not(src_mask), repeats=nhead, dim=0
+ )
new_src_mask = torch.zeros_like(src_mask, dtype=torch.float32)
new_src_mask.masked_fill_(src_mask, -1e10)
return new_src_mask
-def make_tgt_and_memory_mask(tgt,
- src,
- inputs_segmentation,
- targets_segmentation,
- decode,
- nhead):
- """ Utility for creating target and memory mask and adjust them for PyTorch
+def make_tgt_and_memory_mask(
+ tgt, src, inputs_segmentation, targets_segmentation, decode, nhead
+):
+  """Utility for creating target and memory masks and adjusting them for the PyTorch
Transformer API."""
if not decode:
tgt_mask = torch.logical_and(
- torch.mul((tgt > 0).unsqueeze(-1), (tgt > 0).unsqueeze(-2)),
- make_causal_mask(tgt, device=tgt.device))
+ torch.mul((tgt > 0).unsqueeze(-1), (tgt > 0).unsqueeze(-2)),
+ make_causal_mask(tgt, device=tgt.device),
+ )
memory_mask = torch.mul((tgt > 0).unsqueeze(-1), (src > 0).unsqueeze(-2))
else:
tgt_mask = None
- memory_mask = torch.mul((torch.ones_like(tgt) > 0).unsqueeze(-1),
- (src > 0).unsqueeze(-2))
+ memory_mask = torch.mul(
+ (torch.ones_like(tgt) > 0).unsqueeze(-1), (src > 0).unsqueeze(-2)
+ )
# Add segmentation block-diagonal attention masks if using segmented data.
if inputs_segmentation is not None:
tgt_mask = torch.logical_and(
- tgt_mask,
- torch.eq(
- targets_segmentation.unsqueeze(-1),
- targets_segmentation.unsqueeze(-2)))
+ tgt_mask,
+ torch.eq(
+ targets_segmentation.unsqueeze(-1), targets_segmentation.unsqueeze(-2)
+ ),
+ )
memory_mask = torch.logical_and(
- memory_mask,
- torch.eq(
- targets_segmentation.unsqueeze(-1),
- inputs_segmentation.unsqueeze(-2)))
+ memory_mask,
+ torch.eq(
+ targets_segmentation.unsqueeze(-1), inputs_segmentation.unsqueeze(-2)
+ ),
+ )
# Flip values and ensure numerical stability.
memory_mask = torch.repeat_interleave(
- torch.logical_not(memory_mask), repeats=nhead, dim=0)
+ torch.logical_not(memory_mask), repeats=nhead, dim=0
+ )
new_memory_mask = torch.zeros_like(memory_mask, dtype=torch.float32)
new_memory_mask.masked_fill_(memory_mask, -1e10)
if tgt_mask is not None:
tgt_mask = torch.repeat_interleave(
- torch.logical_not(tgt_mask), repeats=nhead, dim=0)
+ torch.logical_not(tgt_mask), repeats=nhead, dim=0
+ )
new_tgt_mask = torch.zeros_like(tgt_mask, dtype=torch.float32)
new_tgt_mask.masked_fill_(tgt_mask, -1e10)
tgt_mask = new_tgt_mask
@@ -100,38 +104,44 @@ def shift_right(x, axis=1):
class Transformer(nn.Module):
"""Transformer architecture based on the model from the WMT Jax workload."""
- def __init__(self,
- ntoken: int = 32000,
- d_model: int = 1024,
- nhead: int = 16,
- d_hid: int = 1024,
- nlayers: int = 6,
- activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
- glu: bool = False,
- layer_norm_eps: float = 1e-6,
- attention_temp: float = 1.0,
- pre_ln: bool = True):
+ def __init__(
+ self,
+ ntoken: int = 32000,
+ d_model: int = 1024,
+ nhead: int = 16,
+ d_hid: int = 1024,
+ nlayers: int = 6,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ glu: bool = False,
+ layer_norm_eps: float = 1e-6,
+ attention_temp: float = 1.0,
+ pre_ln: bool = True,
+ ):
super().__init__()
self.pos_encoder = PositionalEncoding(d_model)
self.shared_embedding = nn.Embedding(ntoken, d_model)
- self.encoder = Encoder(d_model,
- nhead,
- d_hid,
- nlayers,
- activation,
- glu,
- layer_norm_eps,
- attention_temp,
- pre_ln)
- self.decoder = Decoder(d_model,
- nhead,
- d_hid,
- nlayers,
- activation,
- glu,
- layer_norm_eps,
- attention_temp,
- pre_ln)
+ self.encoder = Encoder(
+ d_model,
+ nhead,
+ d_hid,
+ nlayers,
+ activation,
+ glu,
+ layer_norm_eps,
+ attention_temp,
+ pre_ln,
+ )
+ self.decoder = Decoder(
+ d_model,
+ nhead,
+ d_hid,
+ nlayers,
+ activation,
+ glu,
+ layer_norm_eps,
+ attention_temp,
+ pre_ln,
+ )
# Share positional encoding and embedding between encoder and decoder.
self.encoder.pos_encoder = self.pos_encoder
self.encoder.shared_embedding = self.shared_embedding
@@ -148,15 +158,17 @@ def _reset_parameters(self):
if module.bias is not None:
normal_(module.bias, std=1e-6)
- def forward(self,
- src: Tensor,
- tgt: Tensor,
- inputs_positions: Optional[Tensor] = None,
- targets_positions: Optional[Tensor] = None,
- inputs_segmentation: Optional[Tensor] = None,
- targets_segmentation: Optional[Tensor] = None,
- decode: bool = False,
- dropout_rate: float = DROPOUT_RATE) -> Tensor:
+ def forward(
+ self,
+ src: Tensor,
+ tgt: Tensor,
+ inputs_positions: Optional[Tensor] = None,
+ targets_positions: Optional[Tensor] = None,
+ inputs_segmentation: Optional[Tensor] = None,
+ targets_segmentation: Optional[Tensor] = None,
+ decode: bool = False,
+ dropout_rate: float = DROPOUT_RATE,
+ ) -> Tensor:
"""
Args:
src: Tensor, shape [batch_size, seq_len]
@@ -175,19 +187,21 @@ def forward(self,
raise RuntimeError('The batch size of src and tgt must be equal.')
memory = self.encoder(
- src,
- inputs_positions=inputs_positions,
- inputs_segmentation=inputs_segmentation,
- dropout_rate=dropout_rate)
+ src,
+ inputs_positions=inputs_positions,
+ inputs_segmentation=inputs_segmentation,
+ dropout_rate=dropout_rate,
+ )
output = self.decoder(
- tgt,
- memory,
- src, # just for calculating the padding mask
- targets_positions=targets_positions,
- inputs_segmentation=inputs_segmentation,
- targets_segmentation=targets_segmentation,
- decode=decode,
- dropout_rate=dropout_rate)
+ tgt,
+ memory,
+ src, # just for calculating the padding mask
+ targets_positions=targets_positions,
+ inputs_segmentation=inputs_segmentation,
+ targets_segmentation=targets_segmentation,
+ decode=decode,
+ dropout_rate=dropout_rate,
+ )
return output
@@ -210,26 +224,32 @@ class TransformerEncoder(nn.Module):
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
"""
+
__constants__ = ['norm']
- def __init__(self,
- encoder_layer,
- num_layers,
- norm=None,
- enable_nested_tensor=True,
- mask_check=True):
+ def __init__(
+ self,
+ encoder_layer,
+ num_layers,
+ norm=None,
+ enable_nested_tensor=True,
+ mask_check=True,
+ ):
super().__init__()
self.layers = nn.ModuleList(
- [copy.deepcopy(encoder_layer) for _ in range(num_layers)])
+ [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
+ )
self.num_layers = num_layers
self.norm = norm
self.enable_nested_tensor = enable_nested_tensor
self.mask_check = mask_check
- def forward(self,
- src: Tensor,
- mask: Optional[Tensor] = None,
- dropout_rate: Optional[float] = 0.0) -> Tensor:
+ def forward(
+ self,
+ src: Tensor,
+ mask: Optional[Tensor] = None,
+ dropout_rate: Optional[float] = 0.0,
+ ) -> Tensor:
"""Pass the input through the encoder layers in turn.
Args:
@@ -247,7 +267,7 @@ def forward(self,
output = mod(output, src_mask=mask, dropout_rate=dropout_rate)
if convert_to_nested:
- output = output.to_padded_tensor(0.)
+ output = output.to_padded_tensor(0.0)
if self.norm is not None:
output = self.norm(output)
@@ -256,39 +276,42 @@ def forward(self,
class Encoder(nn.Module):
-
- def __init__(self,
- d_model: int = 1024,
- nhead: int = 16,
- d_hid: int = 1024,
- nlayers: int = 6,
- activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
- glu: bool = False,
- layer_norm_eps: float = 1e-6,
- attention_temp: float = 1.0,
- pre_ln: bool = True):
+ def __init__(
+ self,
+ d_model: int = 1024,
+ nhead: int = 16,
+ d_hid: int = 1024,
+ nlayers: int = 6,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ glu: bool = False,
+ layer_norm_eps: float = 1e-6,
+ attention_temp: float = 1.0,
+ pre_ln: bool = True,
+ ):
super().__init__()
self.nhead = nhead
self.shared_embedding = None
self.pos_encoder = None
encoder_layer = TransformerEncoderLayer(
- d_model,
- nhead,
- d_hid,
- activation=activation,
- glu=glu,
- layer_norm_eps=layer_norm_eps,
- attention_temp=attention_temp,
- pre_ln=pre_ln)
- encoder_norm = (
- nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None)
+ d_model,
+ nhead,
+ d_hid,
+ activation=activation,
+ glu=glu,
+ layer_norm_eps=layer_norm_eps,
+ attention_temp=attention_temp,
+ pre_ln=pre_ln,
+ )
+ encoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None
self.encoder = TransformerEncoder(encoder_layer, nlayers, encoder_norm)
- def forward(self,
- src: Tensor,
- inputs_positions: Optional[Tensor] = None,
- inputs_segmentation: Optional[Tensor] = None,
- dropout_rate: Optional[float] = 0.0) -> Tensor:
+ def forward(
+ self,
+ src: Tensor,
+ inputs_positions: Optional[Tensor] = None,
+ inputs_segmentation: Optional[Tensor] = None,
+ dropout_rate: Optional[float] = 0.0,
+ ) -> Tensor:
src = src.to(torch.int)
src_mask = make_src_mask(src, inputs_segmentation, self.nhead)
src = self.shared_embedding(src)
@@ -298,67 +321,73 @@ def forward(self,
class Decoder(nn.Module):
-
- def __init__(self,
- d_model: int = 1024,
- nhead: int = 16,
- d_hid: int = 1024,
- nlayers: int = 6,
- activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
- glu: bool = False,
- layer_norm_eps: float = 1e-6,
- attention_temp: float = 1.0,
- pre_ln: bool = True):
+ def __init__(
+ self,
+ d_model: int = 1024,
+ nhead: int = 16,
+ d_hid: int = 1024,
+ nlayers: int = 6,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ glu: bool = False,
+ layer_norm_eps: float = 1e-6,
+ attention_temp: float = 1.0,
+ pre_ln: bool = True,
+ ):
super().__init__()
self.nhead = nhead
self.shared_embedding = None
self.pos_encoder = None
- self.decoder = TransformerDecoder(d_model,
- nhead,
- d_hid,
- activation,
- glu,
- layer_norm_eps,
- nlayers,
- attention_temp,
- pre_ln)
+ self.decoder = TransformerDecoder(
+ d_model,
+ nhead,
+ d_hid,
+ activation,
+ glu,
+ layer_norm_eps,
+ nlayers,
+ attention_temp,
+ pre_ln,
+ )
def forward(
- self,
- tgt: Tensor,
- memory: Tensor,
- src: Tensor, # just for calculating the padding mask
- targets_positions: Optional[Tensor] = None,
- inputs_segmentation: Optional[Tensor] = None,
- targets_segmentation: Optional[Tensor] = None,
- decode: bool = False,
- max_len: Optional[int] = None,
- cache: Optional[dict] = None,
- dropout_rate: Optional[float] = 0.0) -> Any:
+ self,
+ tgt: Tensor,
+ memory: Tensor,
+ src: Tensor, # just for calculating the padding mask
+ targets_positions: Optional[Tensor] = None,
+ inputs_segmentation: Optional[Tensor] = None,
+ targets_segmentation: Optional[Tensor] = None,
+ decode: bool = False,
+ max_len: Optional[int] = None,
+ cache: Optional[dict] = None,
+ dropout_rate: Optional[float] = 0.0,
+ ) -> Any:
tgt = tgt.to(torch.int)
tgt_mask, memory_mask = make_tgt_and_memory_mask(
- tgt, src, inputs_segmentation, targets_segmentation,
- decode, self.nhead)
+ tgt, src, inputs_segmentation, targets_segmentation, decode, self.nhead
+ )
if not decode:
tgt = shift_right(tgt)
tgt = self.shared_embedding(tgt)
tgt = self.pos_encoder(
- tgt,
- targets_positions,
- decode=decode,
- cache=cache,
- dropout_rate=dropout_rate)
+ tgt,
+ targets_positions,
+ decode=decode,
+ cache=cache,
+ dropout_rate=dropout_rate,
+ )
if decode:
tgt, cache = tgt
output = self.decoder(
- tgt,
- memory,
- tgt_mask=tgt_mask,
- memory_mask=memory_mask,
- decode=decode,
- max_len=max_len,
- cache=cache,
- dropout_rate=dropout_rate)
+ tgt,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ decode=decode,
+ max_len=max_len,
+ cache=cache,
+ dropout_rate=dropout_rate,
+ )
if decode:
output, cache = output
normalize = math.sqrt(output.shape[-1])
@@ -369,7 +398,6 @@ def forward(
class PositionalEncoding(nn.Module):
-
def __init__(self, d_model: int, max_len: int = 256):
super().__init__()
@@ -377,17 +405,17 @@ def __init__(self, d_model: int, max_len: int = 256):
scale_factor = -math.log(10000.0) / (d_model // 2 - 1)
div_term = torch.exp(torch.arange(d_model // 2) * scale_factor)
pe = torch.zeros(1, max_len, d_model)
- pe[0, :, :d_model // 2] = torch.sin(position * div_term)
- pe[0, :, d_model // 2:2 * (d_model // 2)] = torch.cos(position * div_term)
+ pe[0, :, : d_model // 2] = torch.sin(position * div_term)
+ pe[0, :, d_model // 2 : 2 * (d_model // 2)] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(
- self,
- x: Tensor,
- inputs_positions: Optional[Tensor] = None,
- decode: bool = False,
- cache: Optional[Dict[str, Dict[str, Tensor]]] = None,
- dropout_rate: Optional[float] = 0.0
+ self,
+ x: Tensor,
+ inputs_positions: Optional[Tensor] = None,
+ decode: bool = False,
+ cache: Optional[Dict[str, Dict[str, Tensor]]] = None,
+ dropout_rate: Optional[float] = 0.0,
) -> Union[Tensor, Tuple[Tensor, Dict[str, Dict[str, Tensor]]]]:
"""
Args:
@@ -404,17 +432,18 @@ def forward(
name = self._get_name()
if cache is None:
cache = {
- name: {
- 'cache_index':
- torch.tensor(0, dtype=torch.long, device=self.pe.device),
- },
+ name: {
+ 'cache_index': torch.tensor(
+ 0, dtype=torch.long, device=self.pe.device
+ ),
+ },
}
pe = self.pe[0, cache[name]['cache_index'], :]
cache[name]['cache_index'] += 1
return F.dropout(x + pe, dropout_rate, self.training), cache
if inputs_positions is None:
# normal unpacked case:
- pe = self.pe[:, :x.size(1), :]
+ pe = self.pe[:, : x.size(1), :]
else:
# for packed data we need to use known position indices:
pe = self.pe[0, inputs_positions, :]
@@ -449,28 +478,32 @@ class TransformerEncoderLayer(nn.Module):
>>> src = torch.rand(32, 10, 512)
>>> out = encoder_layer(src)
"""
+
__constants__ = ['pre_ln']
- def __init__(self,
- d_model: int = 1024,
- nhead: int = 16,
- dim_feedforward: int = 1024,
- activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
- glu: bool = False,
- layer_norm_eps: float = 1e-6,
- attention_temp: float = 1.0,
- pre_ln: bool = True,
- device=None,
- dtype=None) -> None:
+ def __init__(
+ self,
+ d_model: int = 1024,
+ nhead: int = 16,
+ dim_feedforward: int = 1024,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ glu: bool = False,
+ layer_norm_eps: float = 1e-6,
+ attention_temp: float = 1.0,
+ pre_ln: bool = True,
+ device=None,
+ dtype=None,
+ ) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.self_attn = MultiheadAttention(
- d_model,
- nhead,
- self_attn=True,
- attention_temp=attention_temp,
- bias=False,
- **factory_kwargs)
+ d_model,
+ nhead,
+ self_attn=True,
+ attention_temp=attention_temp,
+ bias=False,
+ **factory_kwargs,
+ )
# Implementation of Feedforward model.
self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
@@ -485,10 +518,12 @@ def __init__(self,
self.activation = activation
- def forward(self,
- src: Tensor,
- src_mask: Optional[Tensor] = None,
- dropout_rate: Optional[float] = 0.0) -> Tensor:
+ def forward(
+ self,
+ src: Tensor,
+ src_mask: Optional[Tensor] = None,
+ dropout_rate: Optional[float] = 0.0,
+ ) -> Tensor:
r"""Pass the input through the encoder layer.
Args:
@@ -509,17 +544,19 @@ def forward(self,
return x
# Self-attention block:
- def _sa_block(self,
- x: Tensor,
- attn_mask: Optional[Tensor],
- dropout_rate: Optional[float] = 0.0) -> Tensor:
+ def _sa_block(
+ self,
+ x: Tensor,
+ attn_mask: Optional[Tensor],
+ dropout_rate: Optional[float] = 0.0,
+ ) -> Tensor:
x, _ = self.self_attn(x, attn_mask=attn_mask, dropout_rate=dropout_rate)
return F.dropout(x, dropout_rate, training=self.training)
# Feed forward block:
- def _ff_block(self,
- inputs: Tensor,
- dropout_rate: Optional[float] = 0.0) -> Tensor:
+ def _ff_block(
+ self, inputs: Tensor, dropout_rate: Optional[float] = 0.0
+ ) -> Tensor:
x = self.activation(self.linear1(inputs))
if self.glu:
y = self.linear_glu(inputs)
@@ -548,42 +585,51 @@ class TransformerDecoder(nn.Module):
>>> tgt = torch.rand(20, 32, 512)
>>> out = transformer_decoder(tgt, memory)
"""
+
__constants__ = ['norm']
- def __init__(self,
- d_model,
- nhead,
- d_hid,
- activation,
- glu,
- layer_norm_eps,
- num_layers,
- attention_temp,
- pre_ln):
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ d_hid,
+ activation,
+ glu,
+ layer_norm_eps,
+ num_layers,
+ attention_temp,
+ pre_ln,
+ ):
super().__init__()
- self.layers = nn.ModuleList([
+ self.layers = nn.ModuleList(
+ [
TransformerDecoderLayer(
- d_model,
- nhead,
- d_hid,
- activation,
- glu,
- layer_norm_eps=layer_norm_eps,
- attention_temp=attention_temp,
- pre_ln=pre_ln) for _ in range(num_layers)
- ])
+ d_model,
+ nhead,
+ d_hid,
+ activation,
+ glu,
+ layer_norm_eps=layer_norm_eps,
+ attention_temp=attention_temp,
+ pre_ln=pre_ln,
+ )
+ for _ in range(num_layers)
+ ]
+ )
self.num_layers = num_layers
- self.norm = (nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None)
-
- def forward(self,
- tgt: Tensor,
- memory: Tensor,
- tgt_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- decode: bool = False,
- max_len: Optional[int] = None,
- cache: Optional[dict] = None,
- dropout_rate: Optional[float] = 0.0) -> Any:
+ self.norm = nn.LayerNorm(d_model, eps=layer_norm_eps) if pre_ln else None
+
+ def forward(
+ self,
+ tgt: Tensor,
+ memory: Tensor,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ decode: bool = False,
+ max_len: Optional[int] = None,
+ cache: Optional[dict] = None,
+ dropout_rate: Optional[float] = 0.0,
+ ) -> Any:
r"""Pass the inputs (and mask) through the decoder layer in turn.
Args:
tgt: the sequence to the decoder (required).
@@ -600,15 +646,16 @@ def forward(self,
for idx, mod in enumerate(self.layers):
output, cache = mod(
- output,
- memory,
- tgt_mask=tgt_mask,
- memory_mask=memory_mask,
- decode=decode,
- max_len=max_len,
- cache=cache,
- index=idx,
- dropout_rate=dropout_rate)
+ output,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ decode=decode,
+ max_len=max_len,
+ cache=cache,
+ index=idx,
+ dropout_rate=dropout_rate,
+ )
if self.norm is not None:
output = self.norm(output)
@@ -647,43 +694,48 @@ class TransformerDecoderLayer(nn.Module):
>>> tgt = torch.rand(32, 20, 512)
>>> out = decoder_layer(tgt, memory)
"""
+
__constants__ = ['pre_ln']
- def __init__(self,
- d_model: int = 1024,
- nhead: int = 16,
- dim_feedforward: int = 1024,
- activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
- glu: bool = False,
- layer_norm_eps: float = 1e-6,
- pre_ln: bool = True,
- attention_temp: float = 1.0,
- device=None,
- dtype=None) -> None:
+ def __init__(
+ self,
+ d_model: int = 1024,
+ nhead: int = 16,
+ dim_feedforward: int = 1024,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
+ glu: bool = False,
+ layer_norm_eps: float = 1e-6,
+ pre_ln: bool = True,
+ attention_temp: float = 1.0,
+ device=None,
+ dtype=None,
+ ) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.self_attn = MultiheadAttention(
- d_model,
- nhead,
- self_attn=True,
- attention_temp=attention_temp,
- bias=False,
- **factory_kwargs)
+ d_model,
+ nhead,
+ self_attn=True,
+ attention_temp=attention_temp,
+ bias=False,
+ **factory_kwargs,
+ )
self.multihead_attn = MultiheadAttention(
- d_model,
- nhead,
- self_attn=False,
- attention_temp=attention_temp,
- bias=False,
- **factory_kwargs)
+ d_model,
+ nhead,
+ self_attn=False,
+ attention_temp=attention_temp,
+ bias=False,
+ **factory_kwargs,
+ )
# Implementation of Feedforward model.
self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
self.glu = glu
if self.glu:
- self.linear_glu = nn.Linear(dim_feedforward,
- dim_feedforward,
- **factory_kwargs)
+ self.linear_glu = nn.Linear(
+ dim_feedforward, dim_feedforward, **factory_kwargs
+ )
self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
self.pre_ln = pre_ln
@@ -694,16 +746,17 @@ def __init__(self,
self.activation = activation
def forward( # pylint: disable=arguments-renamed
- self,
- tgt: Tensor,
- memory: Tensor,
- tgt_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- decode: bool = False,
- max_len: Optional[int] = None,
- cache: Optional[dict] = None,
- index: Optional[int] = None,
- dropout_rate: Optional[float] = 0.0) -> Any:
+ self,
+ tgt: Tensor,
+ memory: Tensor,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ decode: bool = False,
+ max_len: Optional[int] = None,
+ cache: Optional[dict] = None,
+ index: Optional[int] = None,
+ dropout_rate: Optional[float] = 0.0,
+ ) -> Any:
r"""Pass the inputs (and mask) through the decoder layer.
Args:
tgt: the sequence to the decoder layer (required).
@@ -721,25 +774,27 @@ def forward( # pylint: disable=arguments-renamed
x = tgt
if self.pre_ln:
sa_out, cache = self._sa_block(
- self.norm1(x),
- tgt_mask,
- decode=decode,
- max_len=max_len,
- cache=cache,
- index=index,
- dropout_rate=dropout_rate)
+ self.norm1(x),
+ tgt_mask,
+ decode=decode,
+ max_len=max_len,
+ cache=cache,
+ index=index,
+ dropout_rate=dropout_rate,
+ )
x = x + sa_out
x = x + self._mha_block(self.norm2(x), memory, memory_mask, dropout_rate)
x = x + self._ff_block(self.norm3(x), dropout_rate)
else:
sa_out, cache = self._sa_block(
- x,
- tgt_mask,
- decode=decode,
- max_len=max_len,
- cache=cache,
- index=index,
- dropout_rate=dropout_rate)
+ x,
+ tgt_mask,
+ decode=decode,
+ max_len=max_len,
+ cache=cache,
+ index=index,
+ dropout_rate=dropout_rate,
+ )
x = self.norm1(x + sa_out)
x = self.norm2(x + self._mha_block(x, memory, memory_mask, dropout_rate))
x = self.norm3(x + self._ff_block(x, dropout_rate))
@@ -748,41 +803,43 @@ def forward( # pylint: disable=arguments-renamed
# Self-attention block:
def _sa_block( # pylint: disable=arguments-renamed
- self,
- x: Tensor,
- attn_mask: Optional[Tensor],
- decode: bool = False,
- max_len: Optional[int] = None,
- cache: Optional[dict] = None,
- index: Optional[int] = None,
- dropout_rate: Optional[float] = 0.0) -> Any:
+ self,
+ x: Tensor,
+ attn_mask: Optional[Tensor],
+ decode: bool = False,
+ max_len: Optional[int] = None,
+ cache: Optional[dict] = None,
+ index: Optional[int] = None,
+ dropout_rate: Optional[float] = 0.0,
+ ) -> Any:
x, cache = self.self_attn(
- x,
- attn_mask=attn_mask,
- decode=decode,
- max_len=max_len,
- cache=cache,
- index=index,
- dropout_rate=dropout_rate)
+ x,
+ attn_mask=attn_mask,
+ decode=decode,
+ max_len=max_len,
+ cache=cache,
+ index=index,
+ dropout_rate=dropout_rate,
+ )
return F.dropout(x, dropout_rate, self.training), cache
# Multihead attention block:
- def _mha_block(self,
- x: Tensor,
- mem: Tensor,
- attn_mask: Optional[Tensor],
- dropout_rate: Optional[float] = 0.0) -> Tensor:
+ def _mha_block(
+ self,
+ x: Tensor,
+ mem: Tensor,
+ attn_mask: Optional[Tensor],
+ dropout_rate: Optional[float] = 0.0,
+ ) -> Tensor:
x, _ = self.multihead_attn(
- x,
- mem,
- attn_mask=attn_mask,
- dropout_rate=dropout_rate)
+ x, mem, attn_mask=attn_mask, dropout_rate=dropout_rate
+ )
return F.dropout(x, dropout_rate, self.training)
# Feed forward block.
- def _ff_block(self,
- inputs: Tensor,
- dropout_rate: Optional[float] = 0.0) -> Tensor:
+ def _ff_block(
+ self, inputs: Tensor, dropout_rate: Optional[float] = 0.0
+ ) -> Tensor:
x = self.activation(self.linear1(inputs))
if self.glu:
y = self.linear_glu(inputs)
@@ -815,33 +872,38 @@ class MultiheadAttention(nn.Module):
>>> attn_output, cache = multihead_attn(x)
"""
- def __init__(self,
- embed_dim: int,
- num_heads: int,
- self_attn: bool = True,
- attention_temp: float = 1.0,
- bias: bool = False,
- device: Optional[torch.device] = None,
- dtype: Optional[torch.dtype] = None) -> None:
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ self_attn: bool = True,
+ attention_temp: float = 1.0,
+ bias: bool = False,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> None:
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.self_attn = self_attn
self.head_dim = embed_dim // num_heads
self.attention_temp = attention_temp
- assert self.head_dim * num_heads == self.embed_dim, \
- 'embed_dim must be divisible by num_heads.'
+ assert self.head_dim * num_heads == self.embed_dim, (
+ 'embed_dim must be divisible by num_heads.'
+ )
factory_kwargs = {'device': device, 'dtype': dtype}
if self_attn:
# Self-attention.
self.in_proj = nn.Linear(
- embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
+ embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
+ )
else:
# Encoder-decoder attention.
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
self.kv_proj = nn.Linear(
- embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
+ embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs
+ )
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
self._reset_parameters()
@@ -855,15 +917,15 @@ def _reset_parameters(self):
normal_(module.bias, std=1e-6)
def forward(
- self,
- x: Tensor,
- mem: Optional[Tensor] = None,
- attn_mask: Optional[Tensor] = None,
- decode: bool = False,
- max_len: Optional[int] = None,
- cache: Optional[dict] = None,
- index: Optional[int] = None,
- dropout_rate: Optional[float] = 0.0
+ self,
+ x: Tensor,
+ mem: Optional[Tensor] = None,
+ attn_mask: Optional[Tensor] = None,
+ decode: bool = False,
+ max_len: Optional[int] = None,
+ cache: Optional[dict] = None,
+ index: Optional[int] = None,
+ dropout_rate: Optional[float] = 0.0,
) -> Any: # TODO: (nico) remove default?!
r"""
Args:
@@ -872,7 +934,7 @@ def forward(
attention mechanism. See "Attention Is All You Need" for more details.
mem: Batch of input sequences of shape
(batch size, sequence length, embedding dimensionality) for
- encoder-decoder attention. See "Attention Is All You Need" for more
+ encoder-decoder attention. See "Attention Is All You Need" for more
details.
attn_mask: If specified, a 2D or 3D mask preventing attention to certain
positions. Must be of shape :math:`(L, S)` or
@@ -915,16 +977,13 @@ def forward(
if decode:
if loc_cache is None:
loc_cache = {
- 'cached_key':
- torch.zeros((bsz, max_len, embed_dim),
- dtype=k.dtype,
- device=k.device),
- 'cached_value':
- torch.zeros((bsz, max_len, embed_dim),
- dtype=v.dtype,
- device=v.device),
- 'cache_index':
- torch.tensor(0, dtype=torch.long, device=k.device),
+ 'cached_key': torch.zeros(
+ (bsz, max_len, embed_dim), dtype=k.dtype, device=k.device
+ ),
+ 'cached_value': torch.zeros(
+ (bsz, max_len, embed_dim), dtype=v.dtype, device=v.device
+ ),
+ 'cache_index': torch.tensor(0, dtype=torch.long, device=k.device),
}
cached_key = loc_cache['cached_key']
cached_value = loc_cache['cached_value']
@@ -932,11 +991,13 @@ def forward(
# Shape check of cached keys against query input.
expected_shape = (bsz, 1, embed_dim)
if expected_shape != x.shape:
- raise ValueError('Autoregressive cache shape error, expected query '
- f'shape {expected_shape} instead got {x.shape}.')
+ raise ValueError(
+ 'Autoregressive cache shape error, expected query '
+ f'shape {expected_shape} instead got {x.shape}.'
+ )
# Update key, value caches with our new 1d spatial slices.
- cached_key[:, cache_index:cache_index + 1, :] = k
- cached_value[:, cache_index:cache_index + 1, :] = v
+ cached_key[:, cache_index : cache_index + 1, :] = k
+ cached_value[:, cache_index : cache_index + 1, :] = v
k = cached_key
v = cached_value
cache_index += 1
@@ -946,8 +1007,9 @@ def forward(
# not the remaining zero elements.
if attn_mask is not None:
raise ValueError('Attention mask has to be None for decode == True.')
- attn_mask = (torch.arange(max_len, device=k.device) >=
- cache_index).reshape(1, max_len)
+ attn_mask = (
+ torch.arange(max_len, device=k.device) >= cache_index
+ ).reshape(1, max_len)
# Update sequence length to account for complete sequence.
seq_len = k.size(1)
@@ -959,17 +1021,21 @@ def forward(
# Check dtype and shape of attention mask.
if not decode and attn_mask is not None:
- assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \
- f'Float and bool dtypes are supported, not {attn_mask.dtype}.'
+ assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, (
+ f'Float and bool dtypes are supported, not {attn_mask.dtype}.'
+ )
# Ensure attn_mask's dim is 3.
if attn_mask.dim() == 3:
correct_3d_size = (bsz * self.num_heads, tgt_len, seq_len)
if attn_mask.shape != correct_3d_size:
- raise RuntimeError(f'The shape of attn_mask is {attn_mask.shape}, '
- f'but should be {correct_3d_size}.')
+ raise RuntimeError(
+ f'The shape of attn_mask is {attn_mask.shape}, '
+ f'but should be {correct_3d_size}.'
+ )
else:
raise RuntimeError(
- f"attn_mask's dimension {attn_mask.dim()} is not supported")
+ f"attn_mask's dimension {attn_mask.dim()} is not supported"
+ )
# Reshape attention mask to be consistent with q, k, v.
attn_mask = attn_mask.view(bsz, self.num_heads, tgt_len, seq_len)
@@ -985,10 +1051,12 @@ def forward(
# Calculate attention.
q = self.attention_temp * q
attn_output = torch.nn.functional.scaled_dot_product_attention(
- q, k, v, attn_mask, attn_dropout_rate)
+ q, k, v, attn_mask, attn_dropout_rate
+ )
# Rearrange for output projection.
- attn_output = attn_output.transpose(1, 2).contiguous().view(
- bsz, tgt_len, embed_dim)
+ attn_output = (
+ attn_output.transpose(1, 2).contiguous().view(bsz, tgt_len, embed_dim)
+ )
# Output projection.
attn_output = self.out_proj(attn_output)
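
A recurring pattern in the layers above is that the dropout probability arrives as a call-time argument and is applied with F.dropout(x, dropout_rate, self.training) instead of through a pre-built nn.Dropout module, so the rate can change between calls without rebuilding the model. A minimal sketch of that pattern follows; TinyBlock is a made-up module for illustration, not one of the workload classes.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyBlock(nn.Module):
  def __init__(self, d_model: int = 8):
    super().__init__()
    self.linear = nn.Linear(d_model, d_model)

  def forward(self, x: torch.Tensor, dropout_rate: float = 0.1) -> torch.Tensor:
    # The rate comes in through forward(); self.training still gates whether it applies.
    return F.dropout(self.linear(x), p=dropout_rate, training=self.training)

block = TinyBlock()
x = torch.randn(2, 8)
block.train()
y_train = block(x, dropout_rate=0.3)  # stochastic, scaled by 1 / (1 - 0.3)
block.eval()
y_eval = block(x, dropout_rate=0.3)   # identity at eval time
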
diff --git a/algoperf/workloads/wmt/wmt_pytorch/workload.py b/algoperf/workloads/wmt/wmt_pytorch/workload.py
index 4ec816f2f..53d95d393 100644
--- a/algoperf/workloads/wmt/wmt_pytorch/workload.py
+++ b/algoperf/workloads/wmt/wmt_pytorch/workload.py
@@ -3,21 +3,18 @@
import contextlib
from typing import Any, Dict, Optional, Tuple
-from absl import logging
import jax
import tensorflow as tf
import torch
import torch.distributed as dist
-from torch.nn import DataParallel as DP
import torch.nn.functional as F
+from absl import logging
+from torch.nn import DataParallel as DP
from torch.nn.parallel import DistributedDataParallel as DDP
-from algoperf import param_utils
-from algoperf import pytorch_utils
-from algoperf import spec
+from algoperf import param_utils, pytorch_utils, spec
from algoperf.workloads.wmt import bleu
-from algoperf.workloads.wmt.wmt_pytorch import decode
-from algoperf.workloads.wmt.wmt_pytorch import models
+from algoperf.workloads.wmt.wmt_pytorch import decode, models
from algoperf.workloads.wmt.wmt_pytorch.models import Transformer
from algoperf.workloads.wmt.workload import BaseWmtWorkload
@@ -28,11 +25,12 @@ class WmtWorkload(BaseWmtWorkload):
"""WMT PyTorch workload."""
def compute_weighted_cross_entropy(
- self,
- logits: spec.Tensor,
- targets: spec.Tensor,
- weights: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.1) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ logits: spec.Tensor,
+ targets: spec.Tensor,
+ weights: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.1,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Compute weighted cross entropy and entropy for log probs and targets.
Args:
@@ -47,11 +45,14 @@ def compute_weighted_cross_entropy(
valid examples in batch, 'per_example': 1-d array of per-example losses}
"""
if logits.ndim != targets.ndim + 1:
- raise ValueError(f'Incorrect shapes. Got shape {logits.shape} logits and '
- f'{targets.shape} targets.')
+ raise ValueError(
+ f'Incorrect shapes. Got shape {logits.shape} logits and '
+ f'{targets.shape} targets.'
+ )
loss_fn = torch.nn.CrossEntropyLoss(
- reduction='none', label_smoothing=label_smoothing)
+ reduction='none', label_smoothing=label_smoothing
+ )
if N_GPUS > 1 and not USE_PYTORCH_DDP:
loss_fn = DP(loss_fn)
@@ -60,24 +61,27 @@ def compute_weighted_cross_entropy(
if weights is None:
weights = torch.ones_like(targets)
per_example_losses = torch.where(
- weights.to(torch.bool), per_example_losses, 0.)
+ weights.to(torch.bool), per_example_losses, 0.0
+ )
summed_loss = per_example_losses.sum()
n_valid_examples = weights.sum()
return {
- 'summed': summed_loss,
- 'n_valid_examples': n_valid_examples,
- 'per_example': per_example_losses,
+ 'summed': summed_loss,
+ 'n_valid_examples': n_valid_examples,
+ 'per_example': per_example_losses,
}
# Primary eval / decode step functions.
# ----------------------------------------------------------------------------
@torch.no_grad()
- def predict_step(self,
- inputs: spec.Tensor,
- params: spec.ParameterContainer,
- eos_id: int,
- max_decode_len: int,
- beam_size: int = 4) -> spec.Tensor:
+ def predict_step(
+ self,
+ inputs: spec.Tensor,
+ params: spec.ParameterContainer,
+ eos_id: int,
+ max_decode_len: int,
+ beam_size: int = 4,
+ ) -> spec.Tensor:
"""Predict translation with fast decoding beam search on a batch."""
# params = params.module if isinstance(params, (DP, DDP)) else params
if hasattr(params, 'module'):
@@ -86,8 +90,8 @@ def predict_step(self,
if hasattr(params, '_modules'):
params = params._modules
- encoder = params["encoder"]
- decoder = params["decoder"]
+ encoder = params['encoder']
+ decoder = params['decoder']
else:
encoder = params.encoder
decoder = params.decoder
@@ -98,21 +102,23 @@ def predict_step(self,
decoder = DP(decoder)
encoded_inputs = torch.repeat_interleave(
- encoder(inputs), repeats=beam_size, dim=0)
+ encoder(inputs), repeats=beam_size, dim=0
+ )
raw_inputs = torch.repeat_interleave(inputs, repeats=beam_size, dim=0)
def tokens_ids_to_logits(
- flat_ids: spec.Tensor, flat_cache: Dict[str, spec.Tensor]
+ flat_ids: spec.Tensor, flat_cache: Dict[str, spec.Tensor]
) -> Tuple[spec.Tensor, Dict[str, spec.Tensor]]:
"""Token slice to logits from decoder model."""
# --> [batch * beam, 1, vocab]
flat_logits, new_flat_cache = decoder(
- flat_ids,
- encoded_inputs,
- raw_inputs,
- decode=True,
- max_len=max_decode_len,
- cache=flat_cache)
+ flat_ids,
+ encoded_inputs,
+ raw_inputs,
+ decode=True,
+ max_len=max_decode_len,
+ cache=flat_cache,
+ )
# Remove singleton sequence-length dimension:
# [batch * beam, 1, vocab] --> [batch * beam, vocab]
flat_logits = flat_logits.squeeze(dim=1)
@@ -121,24 +127,27 @@ def tokens_ids_to_logits(
# Using the above-defined single-step decoder function, run a
# beam search over possible sequences given input encoding.
beam_seqs, _ = decode.beam_search(
- inputs,
- None,
- tokens_ids_to_logits,
- beam_size=beam_size,
- alpha=0.6,
- eos_id=eos_id,
- max_decode_len=max_decode_len)
+ inputs,
+ None,
+ tokens_ids_to_logits,
+ beam_size=beam_size,
+ alpha=0.6,
+ eos_id=eos_id,
+ max_decode_len=max_decode_len,
+ )
# Beam search returns [n_batch, n_beam, n_length + 1] with beam dimension
# sorted in increasing order of log-probability.
# Return the highest scoring beam sequence, drop first dummy 0 token.
return beam_seqs[:, -1, 1:]
- def translate_and_calculate_bleu(self,
- params: spec.ParameterContainer,
- ds_iter: tf.data.Dataset,
- num_batches: int,
- max_predict_length: int):
+ def translate_and_calculate_bleu(
+ self,
+ params: spec.ParameterContainer,
+ ds_iter: tf.data.Dataset,
+ num_batches: int,
+ max_predict_length: int,
+ ):
"""Translates the `ds_iter` and calculates the BLEU score."""
logging.info('Translating evaluation dataset.')
references, predictions = [], []
@@ -146,10 +155,9 @@ def translate_and_calculate_bleu(self,
pred_batch = next(ds_iter)
inputs = pred_batch['inputs']
targets = pred_batch['targets']
- predicted = self.predict_step(inputs,
- params,
- decode.EOS_ID,
- max_predict_length)
+ predicted = self.predict_step(
+ inputs, params, decode.EOS_ID, max_predict_length
+ )
# Find actual batch size, ignoring the potential padding.
weights = pred_batch.get('weights')
@@ -177,10 +185,11 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
raise ValueError(f'Unknown activation function {self.activation}.')
model = Transformer(
- pre_ln=self.pre_ln,
- attention_temp=self.attention_temp,
- activation=activation,
- glu=self.glu)
+ pre_ln=self.pre_ln,
+ attention_temp=self.attention_temp,
+ activation=activation,
+ glu=self.glu,
+ )
self._param_shapes = param_utils.pytorch_param_shapes(model)
self._param_types = param_utils.pytorch_param_types(self._param_shapes)
model.to(DEVICE)
@@ -195,14 +204,14 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'shared_embedding.weight'
def model_fn(
- self,
- params: spec.ParameterContainer,
- augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
- model_state: spec.ModelAuxiliaryState,
- mode: spec.ForwardPassMode,
- rng: spec.RandomState,
- update_batch_norm: bool,
- dropout_rate: float = models.DROPOUT_RATE
+ self,
+ params: spec.ParameterContainer,
+ augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
+ model_state: spec.ModelAuxiliaryState,
+ mode: spec.ForwardPassMode,
+ rng: spec.RandomState,
+ update_batch_norm: bool,
+ dropout_rate: float = models.DROPOUT_RATE,
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
del model_state
del rng
@@ -214,44 +223,53 @@ def model_fn(
model.eval()
contexts = {
- spec.ForwardPassMode.EVAL: torch.no_grad,
- spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
+ spec.ForwardPassMode.EVAL: torch.no_grad,
+ spec.ForwardPassMode.TRAIN: contextlib.nullcontext,
}
with contexts[mode]():
logits_batch = model(
- src=augmented_and_preprocessed_input_batch['inputs'],
- tgt=augmented_and_preprocessed_input_batch['targets'],
- inputs_positions=augmented_and_preprocessed_input_batch.get(
- 'inputs_position', None),
- targets_positions=augmented_and_preprocessed_input_batch.get(
- 'targets_position', None),
- inputs_segmentation=augmented_and_preprocessed_input_batch.get(
- 'inputs_segmentation', None),
- targets_segmentation=augmented_and_preprocessed_input_batch.get(
- 'targets_segmentation', None),
- dropout_rate=dropout_rate)
+ src=augmented_and_preprocessed_input_batch['inputs'],
+ tgt=augmented_and_preprocessed_input_batch['targets'],
+ inputs_positions=augmented_and_preprocessed_input_batch.get(
+ 'inputs_position', None
+ ),
+ targets_positions=augmented_and_preprocessed_input_batch.get(
+ 'targets_position', None
+ ),
+ inputs_segmentation=augmented_and_preprocessed_input_batch.get(
+ 'inputs_segmentation', None
+ ),
+ targets_segmentation=augmented_and_preprocessed_input_batch.get(
+ 'targets_segmentation', None
+ ),
+ dropout_rate=dropout_rate,
+ )
return logits_batch, None
- def _build_input_queue(self,
- data_rng: jax.random.PRNGKey,
- split: str,
- data_dir: str,
- global_batch_size: int,
- num_batches: Optional[int] = None,
- repeat_final_dataset: bool = False):
+ def _build_input_queue(
+ self,
+ data_rng: jax.random.PRNGKey,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ num_batches: Optional[int] = None,
+ repeat_final_dataset: bool = False,
+ ):
per_device_batch_size = int(global_batch_size / N_GPUS)
n_inputs = 7 if split == 'train' else 3
# The input pipeline has to be created in all processes, because
# self._tokenizer has to be available in every process.
- np_iter = super()._build_input_queue(data_rng,
- split,
- data_dir,
- global_batch_size,
- num_batches,
- repeat_final_dataset)
+ np_iter = super()._build_input_queue(
+ data_rng,
+ split,
+ data_dir,
+ global_batch_size,
+ num_batches,
+ repeat_final_dataset,
+ )
# We only need np_iter in one Python process.
if RANK != 0:
del np_iter
@@ -266,14 +284,15 @@ def _build_input_queue(self,
tensor = torch.as_tensor(value, dtype=torch.int64, device=DEVICE)
tensor_list.append(tensor)
batch[key] = (
- tensor[0] if USE_PYTORCH_DDP else tensor.view(
- -1, value.shape[-1]))
+ tensor[0] if USE_PYTORCH_DDP else tensor.view(-1, value.shape[-1])
+ )
# Send batch to other devices when using DDP.
if USE_PYTORCH_DDP:
# During eval, the batch size of the remainder might be different.
if split != 'train':
per_device_batch_size = torch.tensor(
- len(batch['inputs']), dtype=torch.int32, device=DEVICE)
+ len(batch['inputs']), dtype=torch.int32, device=DEVICE
+ )
dist.broadcast(per_device_batch_size, src=0)
# We don't need to broadcast the batch for the device with RANK == 0.
dist.broadcast(torch.stack(tensor_list)[:, 1:].contiguous(), src=0)
@@ -281,25 +300,27 @@ def _build_input_queue(self,
batch = {}
# During eval, the batch size of the remainder might be different.
if split != 'train':
- per_device_batch_size = torch.empty((1,),
- dtype=torch.int32,
- device=DEVICE)
+ per_device_batch_size = torch.empty(
+ (1,), dtype=torch.int32, device=DEVICE
+ )
dist.broadcast(per_device_batch_size, src=0)
# N_GPUS - 1 since we don't broadcast the batch for RANK == 0.
- tensor = torch.empty((n_inputs, N_GPUS - 1, per_device_batch_size, 256),
- dtype=torch.int64,
- device=DEVICE)
+ tensor = torch.empty(
+ (n_inputs, N_GPUS - 1, per_device_batch_size, 256),
+ dtype=torch.int64,
+ device=DEVICE,
+ )
dist.broadcast(tensor, src=0)
# Note that the order of the keys is important.
if split == 'train':
keys = [
- 'inputs',
- 'inputs_position',
- 'inputs_segmentation',
- 'targets',
- 'targets_position',
- 'targets_segmentation',
- 'weights',
+ 'inputs',
+ 'inputs_position',
+ 'inputs_segmentation',
+ 'targets',
+ 'targets_position',
+ 'targets_segmentation',
+ 'weights',
]
# For all eval/test splits.
else:
@@ -309,34 +330,35 @@ def _build_input_queue(self,
batch[key] = tensor[n][RANK - 1]
yield batch
- def eval_step(self,
- params: spec.ParameterContainer,
- batch: Dict[str, spec.Tensor]) -> Dict[str, spec.Tensor]:
+ def eval_step(
+ self, params: spec.ParameterContainer, batch: Dict[str, spec.Tensor]
+ ) -> Dict[str, spec.Tensor]:
"""Calculate evaluation metrics on a batch."""
targets = batch['targets']
weights = batch['weights']
logits, _ = self.model_fn(
- params,
- batch,
- mode=spec.ForwardPassMode.EVAL,
- model_state=None,
- rng=None,
- update_batch_norm=False)
- summed_loss = self.compute_weighted_cross_entropy(logits,
- targets,
- weights,
- 0.0)['summed']
+ params,
+ batch,
+ mode=spec.ForwardPassMode.EVAL,
+ model_state=None,
+ rng=None,
+ update_batch_norm=False,
+ )
+ summed_loss = self.compute_weighted_cross_entropy(
+ logits, targets, weights, 0.0
+ )['summed']
acc_sum, weight_sum = self.compute_weighted_accuracy(
- logits, targets, weights)
+ logits, targets, weights
+ )
return {
- 'loss': summed_loss,
- 'accuracy': acc_sum,
- 'denominator': weight_sum,
+ 'loss': summed_loss,
+ 'accuracy': acc_sum,
+ 'denominator': weight_sum,
}
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
del num_examples
if USE_PYTORCH_DDP:
diff --git a/algoperf/workloads/wmt/workload.py b/algoperf/workloads/wmt/workload.py
index 51b33373d..40e4262dd 100644
--- a/algoperf/workloads/wmt/workload.py
+++ b/algoperf/workloads/wmt/workload.py
@@ -60,8 +60,9 @@ def num_eval_train_examples(self) -> int:
# Round up from num_validation_examples (which is the default for
# num_eval_train_examples) to the next multiple of eval_batch_size, so that
# we don't have to extract the correctly sized subset of the training data.
- rounded_up_multiple = math.ceil(self.num_validation_examples /
- self.eval_batch_size)
+ rounded_up_multiple = math.ceil(
+ self.num_validation_examples / self.eval_batch_size
+ )
return rounded_up_multiple * self.eval_batch_size
@property
@@ -115,23 +116,26 @@ def activation(self) -> str:
def glu(self) -> bool:
return False
- def _build_input_queue(self,
- data_rng: jax.random.PRNGKey,
- split: str,
- data_dir: str,
- global_batch_size: int,
- num_batches: Optional[int] = None,
- repeat_final_dataset: bool = False):
+ def _build_input_queue(
+ self,
+ data_rng: jax.random.PRNGKey,
+ split: str,
+ data_dir: str,
+ global_batch_size: int,
+ num_batches: Optional[int] = None,
+ repeat_final_dataset: bool = False,
+ ):
is_training = split == 'train'
ds, self._tokenizer = input_pipeline.get_wmt_dataset(
- data_rng,
- split,
- data_dir,
- is_training=is_training,
- vocab_size=self._vocab_size,
- global_batch_size=global_batch_size,
- num_batches=num_batches,
- repeat_final_dataset=repeat_final_dataset)
+ data_rng,
+ split,
+ data_dir,
+ is_training=is_training,
+ vocab_size=self._vocab_size,
+ global_batch_size=global_batch_size,
+ num_batches=num_batches,
+ repeat_final_dataset=repeat_final_dataset,
+ )
# Separate function is necessary because the code above has to be executed
# when _build_input_queue is called (not when next() is first called on it).
@@ -148,19 +152,21 @@ def _input_queue_generator():
@abc.abstractmethod
def _normalize_eval_metrics(
- self, num_examples: int, total_metrics: Dict[str,
- Any]) -> Dict[str, float]:
+ self, num_examples: int, total_metrics: Dict[str, Any]
+ ) -> Dict[str, float]:
"""Normalize eval metrics."""
- def _eval_model_on_split(self,
- split: str,
- num_examples: int,
- global_batch_size: int,
- params: spec.ParameterContainer,
- model_state: spec.ModelAuxiliaryState,
- rng: spec.RandomState,
- data_dir: str,
- global_step: int = 0) -> Dict[str, float]:
+ def _eval_model_on_split(
+ self,
+ split: str,
+ num_examples: int,
+ global_batch_size: int,
+ params: spec.ParameterContainer,
+ model_state: spec.ModelAuxiliaryState,
+ rng: spec.RandomState,
+ data_dir: str,
+ global_step: int = 0,
+ ) -> Dict[str, float]:
"""Run a full evaluation of the model."""
del model_state
del global_step
@@ -168,12 +174,13 @@ def _eval_model_on_split(self,
if split not in self._eval_iters:
# These iterators will repeat indefinitely.
self._eval_iters[split] = self._build_input_queue(
- rng,
- split,
- data_dir,
- global_batch_size,
- num_batches,
- repeat_final_dataset=True)
+ rng,
+ split,
+ data_dir,
+ global_batch_size,
+ num_batches,
+ repeat_final_dataset=True,
+ )
eval_metrics = {}
for _ in range(num_batches):
@@ -186,16 +193,17 @@ def _eval_model_on_split(self,
eval_results = self._normalize_eval_metrics(num_examples, eval_metrics)
eval_results['bleu'] = self.translate_and_calculate_bleu(
- params=params,
- ds_iter=self._eval_iters[split],
- num_batches=num_batches,
- max_predict_length=256)
+ params=params,
+ ds_iter=self._eval_iters[split],
+ num_batches=num_batches,
+ max_predict_length=256,
+ )
return eval_results
def compute_weighted_accuracy(
- self, logits: spec.Tensor, targets: spec.Tensor,
- weights: spec.Tensor) -> Tuple[spec.Tensor, spec.Tensor]:
+ self, logits: spec.Tensor, targets: spec.Tensor, weights: spec.Tensor
+ ) -> Tuple[spec.Tensor, spec.Tensor]:
"""Compute weighted accuracy for log probs and targets.
Args:
@@ -207,8 +215,10 @@ def compute_weighted_accuracy(
Tuple of scalar summed accuracy and batch normalizing factor.
"""
if logits.ndim != targets.ndim + 1:
- raise ValueError(f'Incorrect shapes. Got shape {logits.shape} logits and '
- f'{targets.shape} targets.')
+ raise ValueError(
+ f'Incorrect shapes. Got shape {logits.shape} logits and '
+ f'{targets.shape} targets.'
+ )
accuracy = (logits.argmax(-1) == targets) * weights
normalizing_factor = weights.sum()
return accuracy.sum(), normalizing_factor
@@ -216,17 +226,18 @@ def compute_weighted_accuracy(
def _decode_tokens(self, toks: spec.Tensor) -> spec.Tensor:
if isinstance(toks, torch.Tensor):
toks = toks.cpu().numpy()
- valid_toks = toks[:np.argmax(toks == decode.EOS_ID) + 1].astype(np.int32)
+ valid_toks = toks[: np.argmax(toks == decode.EOS_ID) + 1].astype(np.int32)
return self._tokenizer.detokenize(valid_toks).numpy().decode('utf-8')
# Does NOT apply regularization, which is left to the submitter to do in
# `update_params`.
def loss_fn(
- self,
- label_batch: spec.Tensor, # Dense or one-hot labels.
- logits_batch: spec.Tensor,
- mask_batch: Optional[spec.Tensor] = None,
- label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable
+ self,
+ label_batch: spec.Tensor, # Dense or one-hot labels.
+ logits_batch: spec.Tensor,
+ mask_batch: Optional[spec.Tensor] = None,
+ label_smoothing: float = 0.0,
+ ) -> Dict[str, spec.Tensor]: # differentiable
"""Evaluate the (masked) loss function at (label_batch, logits_batch).
Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of
@@ -234,7 +245,8 @@ def loss_fn(
(not synced across devices).
"""
return self.compute_weighted_cross_entropy(
- logits_batch,
- label_batch,
- weights=mask_batch,
- label_smoothing=label_smoothing)
+ logits_batch,
+ label_batch,
+ weights=mask_batch,
+ label_smoothing=label_smoothing,
+ )
diff --git a/algoperf/workloads/workloads.py b/algoperf/workloads/workloads.py
index 4712f4e25..4dd4717e9 100644
--- a/algoperf/workloads/workloads.py
+++ b/algoperf/workloads/workloads.py
@@ -1,5 +1,5 @@
-""" Registry of workload info
-"""
+"""Registry of workload info"""
+
import importlib
import inspect
import os
@@ -9,149 +9,151 @@
BASE_WORKLOADS_DIR = 'algoperf/workloads/'
WORKLOADS = {
- 'cifar': {
- 'workload_path': 'cifar/cifar', 'workload_class_name': 'CifarWorkload'
- },
- 'criteo1tb': {
- 'workload_path': 'criteo1tb/criteo1tb',
- 'workload_class_name': 'Criteo1TbDlrmSmallWorkload',
- },
- 'criteo1tb_test': {
- 'workload_path': 'criteo1tb/criteo1tb',
- 'workload_class_name': 'Criteo1TbDlrmSmallTestWorkload',
- },
- 'criteo1tb_layernorm': {
- 'workload_path': 'criteo1tb/criteo1tb',
- 'workload_class_name': 'Criteo1TbDlrmSmallLayerNormWorkload'
- },
- 'criteo1tb_embed_init': {
- 'workload_path': 'criteo1tb/criteo1tb',
- 'workload_class_name': 'Criteo1TbDlrmSmallEmbedInitWorkload'
- },
- 'criteo1tb_resnet': {
- 'workload_path': 'criteo1tb/criteo1tb',
- 'workload_class_name': 'Criteo1TbDlrmSmallResNetWorkload'
- },
- 'fastmri': {
- 'workload_path': 'fastmri/fastmri',
- 'workload_class_name': 'FastMRIWorkload',
- },
- 'fastmri_model_size': {
- 'workload_path': 'fastmri/fastmri',
- 'workload_class_name': 'FastMRIModelSizeWorkload',
- },
- 'fastmri_tanh': {
- 'workload_path': 'fastmri/fastmri',
- 'workload_class_name': 'FastMRITanhWorkload',
- },
- 'fastmri_layernorm': {
- 'workload_path': 'fastmri/fastmri',
- 'workload_class_name': 'FastMRILayerNormWorkload',
- },
- 'imagenet_resnet': {
- 'workload_path': 'imagenet_resnet/imagenet',
- 'workload_class_name': 'ImagenetResNetWorkload',
- },
- 'imagenet_resnet_silu': {
- 'workload_path': 'imagenet_resnet/imagenet',
- 'workload_class_name': 'ImagenetResNetSiLUWorkload',
- },
- 'imagenet_resnet_gelu': {
- 'workload_path': 'imagenet_resnet/imagenet',
- 'workload_class_name': 'ImagenetResNetGELUWorkload',
- },
- 'imagenet_resnet_large_bn_init': {
- 'workload_path': 'imagenet_resnet/imagenet',
- 'workload_class_name': 'ImagenetResNetLargeBNScaleWorkload',
- },
- 'imagenet_vit': {
- 'workload_path': 'imagenet_vit/imagenet',
- 'workload_class_name': 'ImagenetVitWorkload',
- },
- 'imagenet_vit_glu': {
- 'workload_path': 'imagenet_vit/imagenet',
- 'workload_class_name': 'ImagenetVitGluWorkload',
- },
- 'imagenet_vit_post_ln': {
- 'workload_path': 'imagenet_vit/imagenet',
- 'workload_class_name': 'ImagenetVitPostLNWorkload',
- },
- 'imagenet_vit_map': {
- 'workload_path': 'imagenet_vit/imagenet',
- 'workload_class_name': 'ImagenetVitMapWorkload',
- },
- 'librispeech_conformer': {
- 'workload_path': 'librispeech_conformer/librispeech',
- 'workload_class_name': 'LibriSpeechConformerWorkload',
- },
- 'librispeech_conformer_attention_temperature': {
- 'workload_path':
- 'librispeech_conformer/librispeech',
- 'workload_class_name':
- 'LibriSpeechConformerAttentionTemperatureWorkload',
- },
- 'librispeech_conformer_layernorm': {
- 'workload_path': 'librispeech_conformer/librispeech',
- 'workload_class_name': 'LibriSpeechConformerLayerNormWorkload',
- },
- 'librispeech_conformer_gelu': {
- 'workload_path': 'librispeech_conformer/librispeech',
- 'workload_class_name': 'LibriSpeechConformerGeluWorkload',
- },
- 'librispeech_deepspeech': {
- 'workload_path': 'librispeech_deepspeech/librispeech',
- 'workload_class_name': 'LibriSpeechDeepSpeechWorkload',
- },
- 'librispeech_deepspeech_tanh': {
- 'workload_path': 'librispeech_deepspeech/librispeech',
- 'workload_class_name': 'LibriSpeechDeepSpeechTanhWorkload',
- },
- 'librispeech_deepspeech_no_resnet': {
- 'workload_path': 'librispeech_deepspeech/librispeech',
- 'workload_class_name': 'LibriSpeechDeepSpeechNoResNetWorkload',
- },
- 'librispeech_deepspeech_norm_and_spec_aug': {
- 'workload_path': 'librispeech_deepspeech/librispeech',
- 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload',
- },
- 'mnist': {
- 'workload_path': 'mnist/mnist', 'workload_class_name': 'MnistWorkload'
- },
- 'ogbg': {
- 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgWorkload'
- },
- 'ogbg_gelu': {
- 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgGeluWorkload'
- },
- 'ogbg_silu': {
- 'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgSiluWorkload'
- },
- 'ogbg_model_size': {
- 'workload_path': 'ogbg/ogbg',
- 'workload_class_name': 'OgbgModelSizeWorkload'
- },
- 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'},
- 'wmt_post_ln': {
- 'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkloadPostLN'
- },
- 'wmt_attention_temp': {
- 'workload_path': 'wmt/wmt',
- 'workload_class_name': 'WmtWorkloadAttentionTemp'
- },
- 'wmt_glu_tanh': {
- 'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkloadGLUTanH'
- },
+ 'cifar': {
+ 'workload_path': 'cifar/cifar',
+ 'workload_class_name': 'CifarWorkload',
+ },
+ 'criteo1tb': {
+ 'workload_path': 'criteo1tb/criteo1tb',
+ 'workload_class_name': 'Criteo1TbDlrmSmallWorkload',
+ },
+ 'criteo1tb_test': {
+ 'workload_path': 'criteo1tb/criteo1tb',
+ 'workload_class_name': 'Criteo1TbDlrmSmallTestWorkload',
+ },
+ 'criteo1tb_layernorm': {
+ 'workload_path': 'criteo1tb/criteo1tb',
+ 'workload_class_name': 'Criteo1TbDlrmSmallLayerNormWorkload',
+ },
+ 'criteo1tb_embed_init': {
+ 'workload_path': 'criteo1tb/criteo1tb',
+ 'workload_class_name': 'Criteo1TbDlrmSmallEmbedInitWorkload',
+ },
+ 'criteo1tb_resnet': {
+ 'workload_path': 'criteo1tb/criteo1tb',
+ 'workload_class_name': 'Criteo1TbDlrmSmallResNetWorkload',
+ },
+ 'fastmri': {
+ 'workload_path': 'fastmri/fastmri',
+ 'workload_class_name': 'FastMRIWorkload',
+ },
+ 'fastmri_model_size': {
+ 'workload_path': 'fastmri/fastmri',
+ 'workload_class_name': 'FastMRIModelSizeWorkload',
+ },
+ 'fastmri_tanh': {
+ 'workload_path': 'fastmri/fastmri',
+ 'workload_class_name': 'FastMRITanhWorkload',
+ },
+ 'fastmri_layernorm': {
+ 'workload_path': 'fastmri/fastmri',
+ 'workload_class_name': 'FastMRILayerNormWorkload',
+ },
+ 'imagenet_resnet': {
+ 'workload_path': 'imagenet_resnet/imagenet',
+ 'workload_class_name': 'ImagenetResNetWorkload',
+ },
+ 'imagenet_resnet_silu': {
+ 'workload_path': 'imagenet_resnet/imagenet',
+ 'workload_class_name': 'ImagenetResNetSiLUWorkload',
+ },
+ 'imagenet_resnet_gelu': {
+ 'workload_path': 'imagenet_resnet/imagenet',
+ 'workload_class_name': 'ImagenetResNetGELUWorkload',
+ },
+ 'imagenet_resnet_large_bn_init': {
+ 'workload_path': 'imagenet_resnet/imagenet',
+ 'workload_class_name': 'ImagenetResNetLargeBNScaleWorkload',
+ },
+ 'imagenet_vit': {
+ 'workload_path': 'imagenet_vit/imagenet',
+ 'workload_class_name': 'ImagenetVitWorkload',
+ },
+ 'imagenet_vit_glu': {
+ 'workload_path': 'imagenet_vit/imagenet',
+ 'workload_class_name': 'ImagenetVitGluWorkload',
+ },
+ 'imagenet_vit_post_ln': {
+ 'workload_path': 'imagenet_vit/imagenet',
+ 'workload_class_name': 'ImagenetVitPostLNWorkload',
+ },
+ 'imagenet_vit_map': {
+ 'workload_path': 'imagenet_vit/imagenet',
+ 'workload_class_name': 'ImagenetVitMapWorkload',
+ },
+ 'librispeech_conformer': {
+ 'workload_path': 'librispeech_conformer/librispeech',
+ 'workload_class_name': 'LibriSpeechConformerWorkload',
+ },
+ 'librispeech_conformer_attention_temperature': {
+ 'workload_path': 'librispeech_conformer/librispeech',
+ 'workload_class_name': 'LibriSpeechConformerAttentionTemperatureWorkload',
+ },
+ 'librispeech_conformer_layernorm': {
+ 'workload_path': 'librispeech_conformer/librispeech',
+ 'workload_class_name': 'LibriSpeechConformerLayerNormWorkload',
+ },
+ 'librispeech_conformer_gelu': {
+ 'workload_path': 'librispeech_conformer/librispeech',
+ 'workload_class_name': 'LibriSpeechConformerGeluWorkload',
+ },
+ 'librispeech_deepspeech': {
+ 'workload_path': 'librispeech_deepspeech/librispeech',
+ 'workload_class_name': 'LibriSpeechDeepSpeechWorkload',
+ },
+ 'librispeech_deepspeech_tanh': {
+ 'workload_path': 'librispeech_deepspeech/librispeech',
+ 'workload_class_name': 'LibriSpeechDeepSpeechTanhWorkload',
+ },
+ 'librispeech_deepspeech_no_resnet': {
+ 'workload_path': 'librispeech_deepspeech/librispeech',
+ 'workload_class_name': 'LibriSpeechDeepSpeechNoResNetWorkload',
+ },
+ 'librispeech_deepspeech_norm_and_spec_aug': {
+ 'workload_path': 'librispeech_deepspeech/librispeech',
+ 'workload_class_name': 'LibriSpeechDeepSpeechNormAndSpecAugWorkload',
+ },
+ 'mnist': {
+ 'workload_path': 'mnist/mnist',
+ 'workload_class_name': 'MnistWorkload',
+ },
+ 'ogbg': {'workload_path': 'ogbg/ogbg', 'workload_class_name': 'OgbgWorkload'},
+ 'ogbg_gelu': {
+ 'workload_path': 'ogbg/ogbg',
+ 'workload_class_name': 'OgbgGeluWorkload',
+ },
+ 'ogbg_silu': {
+ 'workload_path': 'ogbg/ogbg',
+ 'workload_class_name': 'OgbgSiluWorkload',
+ },
+ 'ogbg_model_size': {
+ 'workload_path': 'ogbg/ogbg',
+ 'workload_class_name': 'OgbgModelSizeWorkload',
+ },
+ 'wmt': {'workload_path': 'wmt/wmt', 'workload_class_name': 'WmtWorkload'},
+ 'wmt_post_ln': {
+ 'workload_path': 'wmt/wmt',
+ 'workload_class_name': 'WmtWorkloadPostLN',
+ },
+ 'wmt_attention_temp': {
+ 'workload_path': 'wmt/wmt',
+ 'workload_class_name': 'WmtWorkloadAttentionTemp',
+ },
+ 'wmt_glu_tanh': {
+ 'workload_path': 'wmt/wmt',
+ 'workload_class_name': 'WmtWorkloadGLUTanH',
+ },
}
BASE_WORKLOADS = [
- 'criteo1tb',
- 'fastmri',
- 'imagenet_resnet',
- 'imagenet_vit',
- 'librispeech_conformer',
- 'librispeech_deepspeech',
- 'ogbg',
- 'wmt'
+ 'criteo1tb',
+ 'fastmri',
+ 'imagenet_resnet',
+ 'imagenet_vit',
+ 'librispeech_conformer',
+ 'librispeech_deepspeech',
+ 'ogbg',
+ 'wmt',
]
@@ -171,10 +173,12 @@ def convert_filepath_to_module(path: str):
return base.replace('/', '.')
-def import_workload(workload_path: str,
- workload_class_name: str,
- return_class=False,
- workload_init_kwargs=None) -> spec.Workload:
+def import_workload(
+ workload_path: str,
+ workload_class_name: str,
+ return_class=False,
+ workload_init_kwargs=None,
+) -> spec.Workload:
"""Import and add the workload to the registry.
This importlib loading is nice to have because it allows runners to avoid
@@ -206,9 +210,10 @@ def import_workload(workload_path: str,
break
if workload_class is None:
raise ValueError(
- f'Could not find member {workload_class_name} in {workload_path}. '
- 'Make sure the Workload class is spelled correctly and defined in '
- 'the top scope of the module.')
+ f'Could not find member {workload_class_name} in {workload_path}. '
+ 'Make sure the Workload class is spelled correctly and defined in '
+ 'the top scope of the module.'
+ )
if return_class:
return workload_class
return workload_class(**workload_init_kwargs)
From 97255544b1daf604dee84c9f1de9f8d122d2e121 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 10:21:23 +0200
Subject: [PATCH 108/123] Format docker/
---
docker/scripts/singularity_converter.py | 23 ++++++++++++-----------
1 file changed, 12 insertions(+), 11 deletions(-)
diff --git a/docker/scripts/singularity_converter.py b/docker/scripts/singularity_converter.py
index 48c521009..a816eb5c2 100644
--- a/docker/scripts/singularity_converter.py
+++ b/docker/scripts/singularity_converter.py
@@ -15,26 +15,27 @@
from spython.main.parse.writers import get_writer
# globals
-ENTRY_POINT = "/bin/bash" # seems to be a good default
+ENTRY_POINT = '/bin/bash' # seems to be a good default
FORCE = False # seems to be a good default
#
-parser = argparse.ArgumentParser(description="Custom Singularity converter")
+parser = argparse.ArgumentParser(description='Custom Singularity converter')
parser.add_argument(
- "-i", "--input", type=str, help="Docker input path", default="Dockerfile")
+ '-i', '--input', type=str, help='Docker input path', default='Dockerfile'
+)
parser.add_argument(
- "-o",
- "--output",
- type=str,
- help="Singularity output path",
- default="Singularity.def",
+ '-o',
+ '--output',
+ type=str,
+ help='Singularity output path',
+ default='Singularity.def',
)
args = parser.parse_args()
INPUT_DOCKERFILE_PATH = args.input
OUTPUT_SINGULARITY_PATH = args.output
# create Docker parser and Singularity writer
-parser = get_parser("docker")
-writer = get_writer("singularity")
+parser = get_parser('docker')
+writer = get_writer('singularity')
# parse Dockerfile into Singularity and suppress %files commands
recipeParser = parser(INPUT_DOCKERFILE_PATH)
@@ -44,5 +45,5 @@
# convert to string and save to output file
result = recipeWriter.convert(runscript=ENTRY_POINT, force=FORCE)
-with open(OUTPUT_SINGULARITY_PATH, "w") as f:
+with open(OUTPUT_SINGULARITY_PATH, 'w') as f:
f.write(result)
From 7b18fff74c8d08b4f6dd338ab652c67d39252e4a Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 10:32:47 +0200
Subject: [PATCH 109/123] Lint tests/
---
tests/modeldiffs/diff.py | 7 +++----
tests/modeldiffs/imagenet_resnet/gelu_compare.py | 3 +--
tests/modeldiffs/imagenet_resnet/silu_compare.py | 3 +--
tests/modeldiffs/librispeech_deepspeech/compare.py | 6 +++---
.../librispeech_deepspeech_noresnet/compare.py | 6 ++++--
.../librispeech_deepspeech_normaug/compare.py | 6 ++++--
.../librispeech_deepspeech_tanh/compare.py | 6 ++++--
tests/modeldiffs/torch2jax_utils.py | 2 +-
tests/modeldiffs/vanilla_sgd_jax.py | 10 +++++-----
tests/modeldiffs/vanilla_sgd_pytorch.py | 8 ++++----
tests/submission_runner_test.py | 8 +++-----
tests/test_baselines.py | 8 +++-----
tests/test_param_shapes.py | 3 ---
tests/test_param_types.py | 6 +-----
tests/test_ssim.py | 3 +--
tests/test_traindiffs.py | 12 ++++++------
.../imagenet_resnet/imagenet_jax/workload_test.py | 2 +-
17 files changed, 45 insertions(+), 54 deletions(-)
diff --git a/tests/modeldiffs/diff.py b/tests/modeldiffs/diff.py
index 8449c3241..a74159028 100644
--- a/tests/modeldiffs/diff.py
+++ b/tests/modeldiffs/diff.py
@@ -1,11 +1,10 @@
-from flax import jax_utils
-from flax.core import FrozenDict
import jax
import numpy as np
import torch
+from flax import jax_utils
+from flax.core import FrozenDict
-from tests.modeldiffs.torch2jax_utils import Torch2Jax
-from tests.modeldiffs.torch2jax_utils import value_transform
+from tests.modeldiffs.torch2jax_utils import Torch2Jax, value_transform
# pylint: disable=dangerous-default-value
diff --git a/tests/modeldiffs/imagenet_resnet/gelu_compare.py b/tests/modeldiffs/imagenet_resnet/gelu_compare.py
index 8aa48382d..a92712ddc 100644
--- a/tests/modeldiffs/imagenet_resnet/gelu_compare.py
+++ b/tests/modeldiffs/imagenet_resnet/gelu_compare.py
@@ -14,8 +14,7 @@
ImagenetResNetGELUWorkload as PyTorchWorkload,
)
from tests.modeldiffs.diff import ModelDiffRunner
-from tests.modeldiffs.imagenet_resnet.compare import key_transform
-from tests.modeldiffs.imagenet_resnet.compare import sd_transform
+from tests.modeldiffs.imagenet_resnet.compare import key_transform, sd_transform
if __name__ == '__main__':
# pylint: disable=locally-disabled, not-callable
diff --git a/tests/modeldiffs/imagenet_resnet/silu_compare.py b/tests/modeldiffs/imagenet_resnet/silu_compare.py
index 393badd18..bbbfd082b 100644
--- a/tests/modeldiffs/imagenet_resnet/silu_compare.py
+++ b/tests/modeldiffs/imagenet_resnet/silu_compare.py
@@ -14,8 +14,7 @@
ImagenetResNetSiLUWorkload as PyTorchWorkload,
)
from tests.modeldiffs.diff import ModelDiffRunner
-from tests.modeldiffs.imagenet_resnet.compare import key_transform
-from tests.modeldiffs.imagenet_resnet.compare import sd_transform
+from tests.modeldiffs.imagenet_resnet.compare import key_transform, sd_transform
if __name__ == '__main__':
# pylint: disable=locally-disabled, not-callable
diff --git a/tests/modeldiffs/librispeech_deepspeech/compare.py b/tests/modeldiffs/librispeech_deepspeech/compare.py
index 12cd11513..bd1073524 100644
--- a/tests/modeldiffs/librispeech_deepspeech/compare.py
+++ b/tests/modeldiffs/librispeech_deepspeech/compare.py
@@ -58,9 +58,9 @@ def sd_transform(sd):
else:
out[k] = sd[k]
elif 'LSTM' in ''.join(k):
- l = out.get(k[:-1], dict())
- l[k[-1]] = sd[k]
- out[k[:-1]] = l
+ l_tmp = out.get(k[:-1], dict())
+ l_tmp[k[-1]] = sd[k]
+ out[k[:-1]] = l_tmp
else:
out[k] = sd[k]
keys_to_del = []
diff --git a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py
index 6a719a84a..8593894e4 100644
--- a/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py
+++ b/tests/modeldiffs/librispeech_deepspeech_noresnet/compare.py
@@ -14,8 +14,10 @@
LibriSpeechDeepSpeechTanhWorkload as PyTorchWorkload,
)
from tests.modeldiffs.diff import ModelDiffRunner
-from tests.modeldiffs.librispeech_deepspeech.compare import key_transform
-from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform
+from tests.modeldiffs.librispeech_deepspeech.compare import (
+ key_transform,
+ sd_transform,
+)
if __name__ == '__main__':
# pylint: disable=locally-disabled, not-callable
diff --git a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py
index c8820d397..27e4760a6 100644
--- a/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py
+++ b/tests/modeldiffs/librispeech_deepspeech_normaug/compare.py
@@ -14,8 +14,10 @@
LibriSpeechDeepSpeechNormAndSpecAugWorkload as PyTorchWorkload,
)
from tests.modeldiffs.diff import ModelDiffRunner
-from tests.modeldiffs.librispeech_deepspeech.compare import key_transform
-from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform
+from tests.modeldiffs.librispeech_deepspeech.compare import (
+ key_transform,
+ sd_transform,
+)
if __name__ == '__main__':
# pylint: disable=locally-disabled, not-callable
diff --git a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py
index 0882f3d1e..7990f063b 100644
--- a/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py
+++ b/tests/modeldiffs/librispeech_deepspeech_tanh/compare.py
@@ -14,8 +14,10 @@
LibriSpeechDeepSpeechNoResNetWorkload as PyTorchWorkload,
)
from tests.modeldiffs.diff import ModelDiffRunner
-from tests.modeldiffs.librispeech_deepspeech.compare import key_transform
-from tests.modeldiffs.librispeech_deepspeech.compare import sd_transform
+from tests.modeldiffs.librispeech_deepspeech.compare import (
+ key_transform,
+ sd_transform,
+)
if __name__ == '__main__':
# pylint: disable=locally-disabled, not-callable
diff --git a/tests/modeldiffs/torch2jax_utils.py b/tests/modeldiffs/torch2jax_utils.py
index 7c77a152c..4c95ca7e4 100644
--- a/tests/modeldiffs/torch2jax_utils.py
+++ b/tests/modeldiffs/torch2jax_utils.py
@@ -1,5 +1,5 @@
-from collections import Counter
import pprint
+from collections import Counter
def jax_like_pytorch_statedict(model, state_dict, keys=None):
diff --git a/tests/modeldiffs/vanilla_sgd_jax.py b/tests/modeldiffs/vanilla_sgd_jax.py
index e80e70b8e..aa7bebd4f 100644
--- a/tests/modeldiffs/vanilla_sgd_jax.py
+++ b/tests/modeldiffs/vanilla_sgd_jax.py
@@ -1,15 +1,15 @@
-from flax import jax_utils
import jax
import jax.numpy as jnp
import optax
+from flax import jax_utils
from algoperf import spec
-from reference_algorithms.target_setting_algorithms.data_selection import (
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
data_selection,
-) # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_submission_base import (
+)
+from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401
update_params,
-) # pylint: disable=unused-import
+)
def init_optimizer_state(
diff --git a/tests/modeldiffs/vanilla_sgd_pytorch.py b/tests/modeldiffs/vanilla_sgd_pytorch.py
index d6613479e..6448ac097 100644
--- a/tests/modeldiffs/vanilla_sgd_pytorch.py
+++ b/tests/modeldiffs/vanilla_sgd_pytorch.py
@@ -1,12 +1,12 @@
import torch
from algoperf import spec
-from reference_algorithms.target_setting_algorithms.data_selection import (
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
data_selection,
-) # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.pytorch_submission_base import (
+)
+from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401
update_params,
-) # pylint: disable=unused-import
+)
def init_optimizer_state(
diff --git a/tests/submission_runner_test.py b/tests/submission_runner_test.py
index b9beb9101..c6c993b7b 100644
--- a/tests/submission_runner_test.py
+++ b/tests/submission_runner_test.py
@@ -9,13 +9,11 @@
import os
import sys
-from absl import flags
-from absl import logging
-from absl.testing import absltest
-from absl.testing import parameterized
+from absl import flags, logging
+from absl.testing import absltest, parameterized
-from algoperf.profiler import PassThroughProfiler
import submission_runner
+from algoperf.profiler import PassThroughProfiler
FLAGS = flags.FLAGS
# Needed to avoid UnparsedFlagAccessError
diff --git a/tests/test_baselines.py b/tests/test_baselines.py
index 9ebc50222..c5097a567 100644
--- a/tests/test_baselines.py
+++ b/tests/test_baselines.py
@@ -8,14 +8,12 @@
import os
import sys
-from absl import flags
-from absl import logging
-from absl.testing import absltest
-from absl.testing import parameterized
+from absl import flags, logging
+from absl.testing import absltest, parameterized
+import submission_runner
from algoperf.profiler import PassThroughProfiler
from algoperf.workloads import workloads
-import submission_runner
FLAGS = flags.FLAGS
# Needed to avoid UnparsedFlagAccessError
diff --git a/tests/test_param_shapes.py b/tests/test_param_shapes.py
index 1badd39ed..2243ce52e 100644
--- a/tests/test_param_shapes.py
+++ b/tests/test_param_shapes.py
@@ -5,8 +5,6 @@
import pytest
from flax.core import FrozenDict
-# isort: skip_file
-# pylint:disable=line-too-long
from algoperf.workloads.cifar.cifar_jax.workload import (
CifarWorkload as JaxCifarWorkload,
)
@@ -67,7 +65,6 @@
from algoperf.workloads.wmt.wmt_pytorch.workload import (
WmtWorkload as PyTorchWmtWorkload,
)
-# pylint:enable=line-too-long
WORKLOADS = [
'cifar',
diff --git a/tests/test_param_types.py b/tests/test_param_types.py
index 1583342ff..9f14f7dd8 100644
--- a/tests/test_param_types.py
+++ b/tests/test_param_types.py
@@ -1,11 +1,8 @@
import jax
import pytest
-
from absl import logging
-from algoperf import spec
-# isort: skip_file
-# pylint:disable=line-too-long
+from algoperf import spec
from algoperf.workloads.cifar.cifar_jax.workload import (
CifarWorkload as JaxCifarWorkload,
)
@@ -66,7 +63,6 @@
from algoperf.workloads.wmt.wmt_pytorch.workload import (
WmtWorkload as PyTorchWmtWorkload,
)
-# pylint:enable=line-too-long
WORKLOADS = [
'cifar',
diff --git a/tests/test_ssim.py b/tests/test_ssim.py
index 7d730c251..dcb3f25e0 100644
--- a/tests/test_ssim.py
+++ b/tests/test_ssim.py
@@ -3,11 +3,10 @@
import os
from typing import Tuple
-from absl.testing import absltest
-from absl.testing import parameterized
import jax.numpy as jnp
import numpy as np
import torch
+from absl.testing import absltest, parameterized
from algoperf.pytorch_utils import pytorch_setup
from algoperf.workloads.fastmri.fastmri_jax.ssim import (
diff --git a/tests/test_traindiffs.py b/tests/test_traindiffs.py
index b1982a3bf..8acfc855a 100644
--- a/tests/test_traindiffs.py
+++ b/tests/test_traindiffs.py
@@ -6,13 +6,10 @@
import pickle
import subprocess
-from subprocess import DEVNULL
-from subprocess import run
-from subprocess import STDOUT
+from subprocess import DEVNULL, STDOUT, run
from absl import flags
-from absl.testing import absltest
-from absl.testing import parameterized
+from absl.testing import absltest, parameterized
from numpy import allclose
FLAGS = flags.FLAGS
@@ -92,7 +89,10 @@ def test_workload(self, workload):
'Train Loss (jax)',
'Train Loss (torch)',
]
- fmt = lambda l: '|' + '|'.join(map(lambda x: f'{x:^20s}', l)) + '|'
+
+ def fmt(line):
+ return '|' + '|'.join(map(lambda x: f'{x:^20s}', line)) + '|'
+
header = fmt(header)
pad = (len(header) - len((name))) // 2
print('=' * pad, name, '=' * (len(header) - len(name) - pad), sep='')
diff --git a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py
index 3d06c9839..60a1af2f2 100644
--- a/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py
+++ b/tests/workloads/imagenet_resnet/imagenet_jax/workload_test.py
@@ -1,8 +1,8 @@
"""Tests for imagenet_resnet/imagenet_jax/workload.py."""
-from absl.testing import absltest
import jax
import jax.numpy as jnp
+from absl.testing import absltest
from algoperf import spec
from algoperf.workloads.imagenet_resnet.imagenet_jax.workload import (
From f34bb6d4d0c4d2e971ebef0ff307defc4b5cdbc5 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 10:34:35 +0200
Subject: [PATCH 110/123] Lint submissions/
---
submissions/submission_checker.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/submissions/submission_checker.py b/submissions/submission_checker.py
index fcc4e1faf..f8af9fb52 100644
--- a/submissions/submission_checker.py
+++ b/submissions/submission_checker.py
@@ -29,7 +29,6 @@
import argparse
import logging
import os
-import subprocess
SELF_TUNING = 'self_tuning'
EXTERNAL_TUNING = 'external_tuning'
From 0aeb545f1f78284936f58d8d6abc07fdf7b8da9a Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 10:47:44 +0200
Subject: [PATCH 111/123] Remove perf. profile tests as they are only
 placeholders
---
scoring/test_performance_profile.py | 23 -----------------------
1 file changed, 23 deletions(-)
delete mode 100644 scoring/test_performance_profile.py
diff --git a/scoring/test_performance_profile.py b/scoring/test_performance_profile.py
deleted file mode 100644
index 01c96de71..000000000
--- a/scoring/test_performance_profile.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import os
-
-from absl.testing import absltest
-
-from scoring import performance_profile, scoring_utils
-
-
-class Test(absltest.TestCase):
- def test_get_workloads_time_to_target(self):
- # TODO(kasimbeg)
- pass
-
- def test_get_best_trial_index(self):
- # TODO(kasimbeg)
- pass
-
- def test_compute_performance_profiles(self):
- # TODO(kasimbeg)
- pass
-
-
-if __name__ == '__main__':
- absltest.main()
From 4802dfbdba1dab3919879da587ef3eb46d3cc310 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 10:47:50 +0200
Subject: [PATCH 112/123] Lint scoring/
---
scoring/performance_profile.py | 2 +-
scoring/scoring_utils.py | 2 +-
scoring/test_scoring_utils.py | 4 +---
3 files changed, 3 insertions(+), 5 deletions(-)
diff --git a/scoring/performance_profile.py b/scoring/performance_profile.py
index d79f705d1..4f2ae9c57 100644
--- a/scoring/performance_profile.py
+++ b/scoring/performance_profile.py
@@ -52,7 +52,7 @@
try:
with open('held_out_workloads_algoperf_v05.json', 'r') as f:
HELDOUT_WORKLOADS = json.load(f)
-except:
+except FileNotFoundError:
HELDOUT_WORKLOADS = None
# These global variables have to be set according to the current set of
diff --git a/scoring/scoring_utils.py b/scoring/scoring_utils.py
index ab639f870..5be6c790c 100644
--- a/scoring/scoring_utils.py
+++ b/scoring/scoring_utils.py
@@ -200,7 +200,7 @@ def get_experiment_df(experiment_dir):
)
try:
trial_df = pd.read_csv(eval_measurements_filepath)
- except FileNotFoundError as e:
+ except FileNotFoundError:
logging.info(f'Could not read {eval_measurements_filepath}')
continue
data['trial'] = (trial, experiment_dir)
diff --git a/scoring/test_scoring_utils.py b/scoring/test_scoring_utils.py
index e3a5f7263..64e141976 100644
--- a/scoring/test_scoring_utils.py
+++ b/scoring/test_scoring_utils.py
@@ -1,8 +1,6 @@
-import os
-
from absl.testing import absltest
-from scoring import performance_profile, scoring_utils
+from scoring import scoring_utils
TEST_LOGFILE = 'test_data/adamw_fastmri_jax_04-18-2023-13-10-58.log'
TEST_DIR = 'test_data/experiment_dir'
From 5e97e7850fc7d3c53398c9eb2bdfd56c93eeb8cf Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 10:49:16 +0200
Subject: [PATCH 113/123] Lint prize_qualification_baselines/
---
.../external_tuning/jax_nadamw_full_budget.py | 4 +-
.../jax_nadamw_target_setting.py | 4 +-
.../pytorch_nadamw_full_budget.py | 8 +-
.../pytorch_nadamw_target_setting.py | 8 +-
.../external_tuning/tuning_search_space.json | 96 +++++++++----------
.../self_tuning/jax_nadamw_full_budget.py | 4 +-
.../self_tuning/jax_nadamw_target_setting.py | 4 +-
.../self_tuning/pytorch_nadamw_full_budget.py | 8 +-
.../pytorch_nadamw_target_setting.py | 8 +-
9 files changed, 65 insertions(+), 79 deletions(-)
diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
index aa1a08f69..ed721d167 100644
--- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
+++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
@@ -18,11 +18,11 @@
# isort: on
import chex
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py
index f1d7d62e0..fdcdd5348 100644
--- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py
+++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py
@@ -18,11 +18,11 @@
# isort: on
import chex
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py
index ecd299988..f6c2faa9d 100644
--- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py
+++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py
@@ -3,13 +3,11 @@
import math
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import torch
-from torch import Tensor
import torch.distributed.nn as dist_nn
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from absl import logging
+from torch import Tensor
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py
index 0d8054135..68ff30b2a 100644
--- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py
+++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py
@@ -3,13 +3,11 @@
import math
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import torch
-from torch import Tensor
import torch.distributed.nn as dist_nn
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from absl import logging
+from torch import Tensor
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
diff --git a/prize_qualification_baselines/external_tuning/tuning_search_space.json b/prize_qualification_baselines/external_tuning/tuning_search_space.json
index 199f77041..ad68372c6 100644
--- a/prize_qualification_baselines/external_tuning/tuning_search_space.json
+++ b/prize_qualification_baselines/external_tuning/tuning_search_space.json
@@ -1,53 +1,47 @@
[
- {
- "dropout_rate": 0.0,
- "label_smoothing": 0.1,
- "learning_rate": 0.001308209823469072,
- "one_minus_beta1": 0.02686663061,
- "beta2": 0.9981232922116359,
- "weight_decay": 0.16375311233774334,
- "warmup_factor": 0.1
- },
- {
- "dropout_rate": 0.0,
- "label_smoothing": 0.2,
- "learning_rate": 0.0008445074561975979,
- "one_minus_beta1": 0.11042418465,
- "beta2": 0.9978504782314613,
- "weight_decay": 0.08135402759553023,
- "warmup_factor": 0.05
- },
- {
- "dropout_rate": 0.0,
- "label_smoothing": 0.0,
- "learning_rate": 0.001308209823469072,
- "one_minus_beta1": 0.02686663061,
- "beta2": 0.9981232922116359,
- "weight_decay": 0.16375311233774334,
- "warmup_factor": 0.1
- },
- {
- "dropout_rate": 0.0,
- "label_smoothing": 0.0,
- "learning_rate": 0.004958460849689891,
- "one_minus_beta1": 0.13625575743,
- "beta2": 0.6291854735396584,
- "weight_decay": 0.1147386261512052,
- "warmup_factor": 0.02
- },
- {
- "dropout_rate": 0.1,
- "label_smoothing": 0.0,
- "learning_rate": 0.0017486387539278373,
- "one_minus_beta1": 0.06733926164,
- "beta2": 0.9955159689799007,
- "weight_decay": 0.08121616522670176,
- "warmup_factor": 0.02
- }
+ {
+ "dropout_rate": 0.0,
+ "label_smoothing": 0.1,
+ "learning_rate": 0.001308209823469072,
+ "one_minus_beta1": 0.02686663061,
+ "beta2": 0.9981232922116359,
+ "weight_decay": 0.16375311233774334,
+ "warmup_factor": 0.1
+ },
+ {
+ "dropout_rate": 0.0,
+ "label_smoothing": 0.2,
+ "learning_rate": 0.0008445074561975979,
+ "one_minus_beta1": 0.11042418465,
+ "beta2": 0.9978504782314613,
+ "weight_decay": 0.08135402759553023,
+ "warmup_factor": 0.05
+ },
+ {
+ "dropout_rate": 0.0,
+ "label_smoothing": 0.0,
+ "learning_rate": 0.001308209823469072,
+ "one_minus_beta1": 0.02686663061,
+ "beta2": 0.9981232922116359,
+ "weight_decay": 0.16375311233774334,
+ "warmup_factor": 0.1
+ },
+ {
+ "dropout_rate": 0.0,
+ "label_smoothing": 0.0,
+ "learning_rate": 0.004958460849689891,
+ "one_minus_beta1": 0.13625575743,
+ "beta2": 0.6291854735396584,
+ "weight_decay": 0.1147386261512052,
+ "warmup_factor": 0.02
+ },
+ {
+ "dropout_rate": 0.1,
+ "label_smoothing": 0.0,
+ "learning_rate": 0.0017486387539278373,
+ "one_minus_beta1": 0.06733926164,
+ "beta2": 0.9955159689799007,
+ "weight_decay": 0.08121616522670176,
+ "warmup_factor": 0.02
+ }
]
-
-
-
-
-
-
diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py
index fb322bd5a..0b4e5aba3 100644
--- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py
+++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py
@@ -18,11 +18,11 @@
# isort: on
import chex
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py
index 99d996bb9..d3efc3a55 100644
--- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py
+++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py
@@ -18,11 +18,11 @@
# isort: on
import chex
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py
index cc54e3b4e..8e8adbeaa 100644
--- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py
+++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py
@@ -4,13 +4,11 @@
import math
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import torch
-from torch import Tensor
import torch.distributed.nn as dist_nn
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from absl import logging
+from torch import Tensor
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py
index bd065dc06..5d7c444a4 100644
--- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py
+++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py
@@ -4,13 +4,11 @@
import math
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import torch
-from torch import Tensor
import torch.distributed.nn as dist_nn
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from absl import logging
+from torch import Tensor
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
From 4ae54188b101229117897ec7662c95b954c5a69f Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 10:57:14 +0200
Subject: [PATCH 114/123] Lint datasets/
---
datasets/librispeech_preprocess.py | 9 ++++-----
datasets/librispeech_tokenizer.py | 6 +++---
2 files changed, 7 insertions(+), 8 deletions(-)
diff --git a/datasets/librispeech_preprocess.py b/datasets/librispeech_preprocess.py
index c419eb39b..1c216db46 100644
--- a/datasets/librispeech_preprocess.py
+++ b/datasets/librispeech_preprocess.py
@@ -4,16 +4,15 @@
import multiprocessing.dummy
import os
-from os.path import exists
import sys
import threading
import time
-from absl import logging
import numpy as np
import pandas as pd
-from pydub import AudioSegment
import tensorflow as tf
+from absl import logging
+from pydub import AudioSegment
from datasets import librispeech_tokenizer
@@ -84,8 +83,8 @@ def process(index):
return utterance_ids
with open(trans_file, 'r', encoding='UTF-8') as f:
- for l in f:
- utt, trans = l.strip().split(' ', maxsplit=1)
+ for line in f:
+ utt, trans = line.strip().split(' ', maxsplit=1)
audio_path = (
f'{data_folder}/{speaker_folder}/{chapter_folder}/{utt}.flac'
)
diff --git a/datasets/librispeech_tokenizer.py b/datasets/librispeech_tokenizer.py
index 5b9888cc2..d566d5716 100644
--- a/datasets/librispeech_tokenizer.py
+++ b/datasets/librispeech_tokenizer.py
@@ -8,10 +8,10 @@
import tempfile
from typing import Dict
-from absl import logging
import sentencepiece as spm
import tensorflow as tf
import tensorflow_text as tftxt
+from absl import logging
gfile = tf.io.gfile
copy = tf.io.gfile.copy
@@ -41,8 +41,8 @@ def dump_chars_for_training(data_folder, splits, maxchars: int = int(1e7)):
logging.info('path does not exist -> %s', trans_file)
continue
with open(trans_file, 'r', encoding='UTF-8') as f:
- for l in f:
- _, line = l.strip().split(' ', maxsplit=1)
+ for lines in f:
+ _, line = lines.strip().split(' ', maxsplit=1)
line = line + '\n'
char_count += len(line)
if char_count > maxchars:
From e3f1b74ae5582b72ee37db62a70e84fa85e152e0 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 11:05:19 +0200
Subject: [PATCH 115/123] Lint reference_algorithms/
---
.../cifar/cifar_jax/submission.py | 4 +-
.../cifar/cifar_pytorch/submission.py | 4 +-
.../cifar/tuning_search_space.json | 10 +--
.../mnist/discrete_space.json | 32 +++++-----
.../mnist/mnist_jax/submission.py | 4 +-
.../mnist/tuning_search_space.json | 6 +-
.../adafactor/jax/sharded_adafactor.py | 2 +-
.../adafactor/jax/submission.py | 4 +-
.../adafactor/pytorch/submission.py | 6 +-
.../adafactor/tuning_search_space.json | 42 ++++++------
.../tuning_search_space_no_beta1.json | 40 ++++++------
.../paper_baselines/adamw/jax/submission.py | 4 +-
.../adamw/pytorch/submission.py | 6 +-
.../adamw/tuning_search_space.json | 48 ++++++++------
.../adamw/tuning_search_space_no_beta1.json | 46 +++++++------
.../paper_baselines/lamb/jax/submission.py | 4 +-
.../lamb/pytorch/submission.py | 6 +-
.../lamb/tuning_search_space.json | 48 ++++++++------
.../lamb/tuning_search_space_no_beta1.json | 46 +++++++------
.../momentum/jax/submission.py | 4 +-
.../momentum/pytorch/submission.py | 2 +-
.../momentum/tuning_search_space.json | 18 ++++--
.../tuning_search_space_no_beta1.json | 16 +++--
.../paper_baselines/nadamw/jax/submission.py | 4 +-
.../nadamw/pytorch/submission.py | 8 +--
.../nadamw/tuning_search_space.json | 20 ++++--
.../nadamw/tuning_search_space_no_beta1.json | 18 ++++--
.../nesterov/jax/submission.py | 4 +-
.../nesterov/pytorch/submission.py | 2 +-
.../nesterov/tuning_search_space.json | 18 ++++--
.../tuning_search_space_no_beta1.json | 16 +++--
.../paper_baselines/sam/jax/submission.py | 4 +-
.../paper_baselines/sam/pytorch/submission.py | 6 +-
.../sam/tuning_search_space.json | 54 +++++++++-------
.../sam/tuning_search_space_no_beta1.json | 52 ++++++++-------
.../shampoo/jax/distributed_shampoo.py | 10 +--
.../paper_baselines/shampoo/jax/submission.py | 4 +-
.../shampoo/tuning_search_space.json | 48 ++++++++------
.../shampoo/tuning_search_space_no_beta1.json | 46 +++++++------
.../cosine_warmup.py | 4 +-
.../tuning_search_space.json | 41 +++++-------
.../tuning_search_space.json | 41 +++++-------
.../criteo1tb_resnet/tuning_search_space.json | 41 +++++-------
.../fastmri/tuning_search_space.json | 56 ++++++----------
.../tuning_search_space.json | 40 +++++-------
.../tuning_search_space.json | 40 +++++-------
.../fastmri_tanh/tuning_search_space.json | 40 +++++-------
.../imagenet_resnet/tuning_search_space.json | 64 +++++++------------
.../tuning_search_space.json | 56 ++++++----------
.../tuning_search_space.json | 56 ++++++----------
.../tuning_search_space.json | 40 +++++-------
.../imagenet_vit/tuning_search_space.json | 40 +++++-------
.../imagenet_vit_glu/tuning_search_space.json | 40 +++++-------
.../imagenet_vit_map/tuning_search_space.json | 40 +++++-------
.../tuning_search_space.json | 40 +++++-------
.../target_setting_algorithms/jax_adamw.py | 14 ++--
.../target_setting_algorithms/jax_momentum.py | 14 ++--
.../target_setting_algorithms/jax_nadamw.py | 14 ++--
.../target_setting_algorithms/jax_nesterov.py | 14 ++--
.../jax_submission_base.py | 2 +-
.../tuning_search_space.json | 40 +++++-------
.../tuning_search_space.json | 40 +++++-------
.../tuning_search_space.json | 40 +++++-------
.../tuning_search_space.json | 40 +++++-------
.../tuning_search_space.json | 40 +++++-------
.../tuning_search_space.json | 40 +++++-------
.../tuning_search_space.json | 40 +++++-------
.../tuning_search_space.json | 40 +++++-------
.../tuning_search_space.json | 40 +++++-------
.../tuning_search_space.json | 40 +++++-------
.../ogbg/tuning_search_space.json | 48 ++++++--------
.../ogbg_gelu/tuning_search_space.json | 42 +++++-------
.../ogbg_model_size/tuning_search_space.json | 42 +++++-------
.../ogbg_silu/tuning_search_space.json | 42 +++++-------
.../pytorch_adamw.py | 12 ++--
.../pytorch_momentum.py | 12 ++--
.../pytorch_nadamw.py | 12 ++--
.../pytorch_nesterov.py | 14 ++--
.../pytorch_submission_base.py | 2 +-
.../wmt/tuning_search_space.json | 48 ++++++--------
.../tuning_search_space.json | 48 ++++++--------
.../wmt_glu_tanh/tuning_search_space.json | 48 ++++++--------
.../wmt_post_ln/tuning_search_space.json | 48 ++++++--------
83 files changed, 988 insertions(+), 1283 deletions(-)
diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py
index 37f74ac45..d080f2fb3 100644
--- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py
+++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py
@@ -3,11 +3,11 @@
import functools
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py
index 5fd51c3b2..e8080fe34 100644
--- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py
+++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py
@@ -3,9 +3,7 @@
from typing import Any, Dict, Iterator, List, Optional, Tuple
import torch
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
diff --git a/reference_algorithms/development_algorithms/cifar/tuning_search_space.json b/reference_algorithms/development_algorithms/cifar/tuning_search_space.json
index 283341705..aa8fcacfd 100644
--- a/reference_algorithms/development_algorithms/cifar/tuning_search_space.json
+++ b/reference_algorithms/development_algorithms/cifar/tuning_search_space.json
@@ -1,7 +1,7 @@
{
- "learning_rate": {"feasible_points": [0.1]},
- "warmup_epochs": {"feasible_points": [5]},
- "num_epochs": {"feasible_points": [200]},
- "l2": {"feasible_points": [5e-4]},
- "momentum": {"feasible_points": [0.9]}
+ "learning_rate": { "feasible_points": [0.1] },
+ "warmup_epochs": { "feasible_points": [5] },
+ "num_epochs": { "feasible_points": [200] },
+ "l2": { "feasible_points": [5e-4] },
+ "momentum": { "feasible_points": [0.9] }
}
diff --git a/reference_algorithms/development_algorithms/mnist/discrete_space.json b/reference_algorithms/development_algorithms/mnist/discrete_space.json
index 310f19e7d..8056d4861 100644
--- a/reference_algorithms/development_algorithms/mnist/discrete_space.json
+++ b/reference_algorithms/development_algorithms/mnist/discrete_space.json
@@ -1,17 +1,17 @@
[
- {
- "learning_rate": 1e-3,
- "one_minus_beta_1": 0.999,
- "epsilon": 0.9
- },
- {
- "learning_rate": 1e-2,
- "one_minus_beta_1": 0.99,
- "epsilon": 0.99
- },
- {
- "learning_rate": 1e-1,
- "one_minus_beta_1": 0.9,
- "epsilon": 0.999
- }
-]
\ No newline at end of file
+ {
+ "learning_rate": 1e-3,
+ "one_minus_beta_1": 0.999,
+ "epsilon": 0.9
+ },
+ {
+ "learning_rate": 1e-2,
+ "one_minus_beta_1": 0.99,
+ "epsilon": 0.99
+ },
+ {
+ "learning_rate": 1e-1,
+ "one_minus_beta_1": 0.9,
+ "epsilon": 0.999
+ }
+]
diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py
index 3ef97577f..afdd1bd43 100644
--- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py
+++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py
@@ -3,11 +3,11 @@
import functools
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
diff --git a/reference_algorithms/development_algorithms/mnist/tuning_search_space.json b/reference_algorithms/development_algorithms/mnist/tuning_search_space.json
index 35b941133..bf7d3f1c1 100644
--- a/reference_algorithms/development_algorithms/mnist/tuning_search_space.json
+++ b/reference_algorithms/development_algorithms/mnist/tuning_search_space.json
@@ -1,5 +1,5 @@
{
- "learning_rate": {"min": 1e-4, "max": 1e-2, "scaling": "log"},
- "one_minus_beta_1": {"min": 0.9, "max": 0.999, "scaling": "log"},
- "epsilon": {"feasible_points": [1e-8, 1e-5, 1e-3]}
+ "learning_rate": { "min": 1e-4, "max": 1e-2, "scaling": "log" },
+ "one_minus_beta_1": { "min": 0.9, "max": 0.999, "scaling": "log" },
+ "epsilon": { "feasible_points": [1e-8, 1e-5, 1e-3] }
}
diff --git a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py
index c83f14a13..32ba97da4 100644
--- a/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py
+++ b/reference_algorithms/paper_baselines/adafactor/jax/sharded_adafactor.py
@@ -30,8 +30,8 @@
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
import jax
-from jax import numpy as jnp
import optax
+from jax import numpy as jnp
JTensor = Any
NestedJTensor = Any
diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py
index 898de35eb..8dcaa6578 100644
--- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py
+++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py
@@ -3,11 +3,11 @@
import functools
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
from reference_algorithms.paper_baselines.adafactor.jax.sharded_adafactor import (
diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py
index dd831566f..4c96e5562 100644
--- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py
@@ -3,12 +3,10 @@
from functools import partial
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import torch
import torch.distributed.nn as dist_nn
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from absl import logging
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
diff --git a/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json b/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json
index 5543689ea..37b36e55d 100644
--- a/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json
+++ b/reference_algorithms/paper_baselines/adafactor/tuning_search_space.json
@@ -1,20 +1,26 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "min": 1e-2, "max": 0.45, "scaling": "log"
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 1e-3, "max": 1.0, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "min": 1e-2,
+ "max": 0.45,
+ "scaling": "log"
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 1e-3,
+ "max": 1.0,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ }
}
diff --git a/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json
index 98a506084..35f106f9d 100644
--- a/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json
+++ b/reference_algorithms/paper_baselines/adafactor/tuning_search_space_no_beta1.json
@@ -1,20 +1,24 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "feasible_points": [0.1]
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 1e-3, "max": 1.0, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "feasible_points": [0.1]
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 1e-3,
+ "max": 1.0,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ }
}
diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py
index 52c8d5ee2..17d0c2fc2 100644
--- a/reference_algorithms/paper_baselines/adamw/jax/submission.py
+++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py
@@ -3,11 +3,11 @@
import functools
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py
index faefcd254..5df907160 100644
--- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py
@@ -2,12 +2,10 @@
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import torch
import torch.distributed.nn as dist_nn
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from absl import logging
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
diff --git a/reference_algorithms/paper_baselines/adamw/tuning_search_space.json b/reference_algorithms/paper_baselines/adamw/tuning_search_space.json
index c96b03eda..abdd6e32d 100644
--- a/reference_algorithms/paper_baselines/adamw/tuning_search_space.json
+++ b/reference_algorithms/paper_baselines/adamw/tuning_search_space.json
@@ -1,23 +1,29 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "min": 2e-2, "max": 0.5, "scaling": "log"
- },
- "beta2": {
- "feasible_points": [0.999]
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 5e-3, "max": 1.0, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "min": 2e-2,
+ "max": 0.5,
+ "scaling": "log"
+ },
+ "beta2": {
+ "feasible_points": [0.999]
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 5e-3,
+ "max": 1.0,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ }
}
diff --git a/reference_algorithms/paper_baselines/adamw/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/adamw/tuning_search_space_no_beta1.json
index b8bd2ea49..5a7c27be7 100644
--- a/reference_algorithms/paper_baselines/adamw/tuning_search_space_no_beta1.json
+++ b/reference_algorithms/paper_baselines/adamw/tuning_search_space_no_beta1.json
@@ -1,23 +1,27 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "feasible_points": [0.1]
- },
- "beta2": {
- "feasible_points": [0.999]
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 5e-3, "max": 1.0, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "feasible_points": [0.1]
+ },
+ "beta2": {
+ "feasible_points": [0.999]
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 5e-3,
+ "max": 1.0,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ }
}
diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py
index eedbbfc37..168c0579b 100644
--- a/reference_algorithms/paper_baselines/lamb/jax/submission.py
+++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py
@@ -3,11 +3,11 @@
import functools
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py
index 5cda59d6f..c73f89e71 100644
--- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py
@@ -3,12 +3,10 @@
import math
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import torch
+from absl import logging
from torch import Tensor
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
diff --git a/reference_algorithms/paper_baselines/lamb/tuning_search_space.json b/reference_algorithms/paper_baselines/lamb/tuning_search_space.json
index f2fcde461..7de33fe47 100644
--- a/reference_algorithms/paper_baselines/lamb/tuning_search_space.json
+++ b/reference_algorithms/paper_baselines/lamb/tuning_search_space.json
@@ -1,23 +1,29 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "min": 5e-2, "max": 0.3, "scaling": "log"
- },
- "beta2": {
- "feasible_points": [0.999]
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 1e-3, "max": 1.0, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "min": 5e-2,
+ "max": 0.3,
+ "scaling": "log"
+ },
+ "beta2": {
+ "feasible_points": [0.999]
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 1e-3,
+ "max": 1.0,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ }
}
diff --git a/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json
index 8934e512d..0f1bc208a 100644
--- a/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json
+++ b/reference_algorithms/paper_baselines/lamb/tuning_search_space_no_beta1.json
@@ -1,23 +1,27 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "feasible_points": [0.1]
- },
- "beta2": {
- "feasible_points": [0.999]
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 1e-3, "max": 1.0, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "feasible_points": [0.1]
+ },
+ "beta2": {
+ "feasible_points": [0.999]
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 1e-3,
+ "max": 1.0,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ }
}
diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py
index 3540f9415..df084c17b 100644
--- a/reference_algorithms/paper_baselines/momentum/jax/submission.py
+++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py
@@ -3,11 +3,11 @@
import functools
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py
index bad750857..b3d38b3dd 100644
--- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py
@@ -2,10 +2,10 @@
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import optax
import torch
import torch.distributed.nn as dist_nn
+from absl import logging
from torch.optim.lr_scheduler import LambdaLR
from algoperf import spec
diff --git a/reference_algorithms/paper_baselines/momentum/tuning_search_space.json b/reference_algorithms/paper_baselines/momentum/tuning_search_space.json
index 8423bdab7..9ec39a6ef 100644
--- a/reference_algorithms/paper_baselines/momentum/tuning_search_space.json
+++ b/reference_algorithms/paper_baselines/momentum/tuning_search_space.json
@@ -1,21 +1,27 @@
{
"learning_rate": {
- "min": 2e-2, "max": 5.0, "scaling": "log"
+ "min": 2e-2,
+ "max": 5.0,
+ "scaling": "log"
},
"one_minus_beta1": {
- "min": 5e-3, "max": 0.3, "scaling": "log"
+ "min": 5e-3,
+ "max": 0.3,
+ "scaling": "log"
},
"warmup_factor": {
- "feasible_points": [0.05]
+ "feasible_points": [0.05]
},
"weight_decay": {
- "min": 1e-7, "max": 5e-5, "scaling": "log"
+ "min": 1e-7,
+ "max": 5e-5,
+ "scaling": "log"
},
"label_smoothing": {
- "feasible_points": [0.1, 0.2]
+ "feasible_points": [0.1, 0.2]
},
"dropout_rate": {
- "feasible_points": [0.0, 0.1]
+ "feasible_points": [0.0, 0.1]
},
"decay_steps_factor": {
"feasible_points": [0.9]
diff --git a/reference_algorithms/paper_baselines/momentum/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/momentum/tuning_search_space_no_beta1.json
index f874862d8..80f9c7968 100644
--- a/reference_algorithms/paper_baselines/momentum/tuning_search_space_no_beta1.json
+++ b/reference_algorithms/paper_baselines/momentum/tuning_search_space_no_beta1.json
@@ -1,21 +1,25 @@
{
"learning_rate": {
- "min": 2e-2, "max": 5.0, "scaling": "log"
+ "min": 2e-2,
+ "max": 5.0,
+ "scaling": "log"
},
"one_minus_beta1": {
- "feasible_points": [0.1]
+ "feasible_points": [0.1]
},
"warmup_factor": {
- "feasible_points": [0.05]
+ "feasible_points": [0.05]
},
"weight_decay": {
- "min": 1e-7, "max": 5e-5, "scaling": "log"
+ "min": 1e-7,
+ "max": 5e-5,
+ "scaling": "log"
},
"label_smoothing": {
- "feasible_points": [0.1, 0.2]
+ "feasible_points": [0.1, 0.2]
},
"dropout_rate": {
- "feasible_points": [0.0, 0.1]
+ "feasible_points": [0.0, 0.1]
},
"decay_steps_factor": {
"feasible_points": [0.9]
diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py
index aa1a08f69..ed721d167 100644
--- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py
+++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py
@@ -18,11 +18,11 @@
# isort: on
import chex
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py
index ecd299988..f6c2faa9d 100644
--- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py
@@ -3,13 +3,11 @@
import math
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import torch
-from torch import Tensor
import torch.distributed.nn as dist_nn
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from absl import logging
+from torch import Tensor
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
diff --git a/reference_algorithms/paper_baselines/nadamw/tuning_search_space.json b/reference_algorithms/paper_baselines/nadamw/tuning_search_space.json
index cba20c4c2..a3d322771 100644
--- a/reference_algorithms/paper_baselines/nadamw/tuning_search_space.json
+++ b/reference_algorithms/paper_baselines/nadamw/tuning_search_space.json
@@ -1,23 +1,29 @@
{
"learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
},
"one_minus_beta1": {
- "min": 4e-3, "max": 0.1, "scaling": "log"
+ "min": 4e-3,
+ "max": 0.1,
+ "scaling": "log"
},
"beta2": {
- "feasible_points": [0.999]
+ "feasible_points": [0.999]
},
"warmup_factor": {
- "feasible_points": [0.05]
+ "feasible_points": [0.05]
},
"weight_decay": {
- "min": 5e-3, "max": 1.0, "scaling": "log"
+ "min": 5e-3,
+ "max": 1.0,
+ "scaling": "log"
},
"label_smoothing": {
- "feasible_points": [0.1, 0.2]
+ "feasible_points": [0.1, 0.2]
},
"dropout_rate": {
- "feasible_points": [0.0, 0.1]
+ "feasible_points": [0.0, 0.1]
}
}
diff --git a/reference_algorithms/paper_baselines/nadamw/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/nadamw/tuning_search_space_no_beta1.json
index 58973eb27..5a7c27be7 100644
--- a/reference_algorithms/paper_baselines/nadamw/tuning_search_space_no_beta1.json
+++ b/reference_algorithms/paper_baselines/nadamw/tuning_search_space_no_beta1.json
@@ -1,23 +1,27 @@
{
"learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
},
"one_minus_beta1": {
- "feasible_points": [0.1]
+ "feasible_points": [0.1]
},
"beta2": {
- "feasible_points": [0.999]
+ "feasible_points": [0.999]
},
"warmup_factor": {
- "feasible_points": [0.05]
+ "feasible_points": [0.05]
},
"weight_decay": {
- "min": 5e-3, "max": 1.0, "scaling": "log"
+ "min": 5e-3,
+ "max": 1.0,
+ "scaling": "log"
},
"label_smoothing": {
- "feasible_points": [0.1, 0.2]
+ "feasible_points": [0.1, 0.2]
},
"dropout_rate": {
- "feasible_points": [0.0, 0.1]
+ "feasible_points": [0.0, 0.1]
}
}
diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py
index d32026212..18e58b3c0 100644
--- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py
+++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py
@@ -3,11 +3,11 @@
import functools
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py
index 77361f472..9d3bfa6e7 100644
--- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py
@@ -2,10 +2,10 @@
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import optax
import torch
import torch.distributed.nn as dist_nn
+from absl import logging
from torch.optim.lr_scheduler import LambdaLR
from algoperf import spec
diff --git a/reference_algorithms/paper_baselines/nesterov/tuning_search_space.json b/reference_algorithms/paper_baselines/nesterov/tuning_search_space.json
index 8423bdab7..9ec39a6ef 100644
--- a/reference_algorithms/paper_baselines/nesterov/tuning_search_space.json
+++ b/reference_algorithms/paper_baselines/nesterov/tuning_search_space.json
@@ -1,21 +1,27 @@
{
"learning_rate": {
- "min": 2e-2, "max": 5.0, "scaling": "log"
+ "min": 2e-2,
+ "max": 5.0,
+ "scaling": "log"
},
"one_minus_beta1": {
- "min": 5e-3, "max": 0.3, "scaling": "log"
+ "min": 5e-3,
+ "max": 0.3,
+ "scaling": "log"
},
"warmup_factor": {
- "feasible_points": [0.05]
+ "feasible_points": [0.05]
},
"weight_decay": {
- "min": 1e-7, "max": 5e-5, "scaling": "log"
+ "min": 1e-7,
+ "max": 5e-5,
+ "scaling": "log"
},
"label_smoothing": {
- "feasible_points": [0.1, 0.2]
+ "feasible_points": [0.1, 0.2]
},
"dropout_rate": {
- "feasible_points": [0.0, 0.1]
+ "feasible_points": [0.0, 0.1]
},
"decay_steps_factor": {
"feasible_points": [0.9]
diff --git a/reference_algorithms/paper_baselines/nesterov/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/nesterov/tuning_search_space_no_beta1.json
index f874862d8..80f9c7968 100644
--- a/reference_algorithms/paper_baselines/nesterov/tuning_search_space_no_beta1.json
+++ b/reference_algorithms/paper_baselines/nesterov/tuning_search_space_no_beta1.json
@@ -1,21 +1,25 @@
{
"learning_rate": {
- "min": 2e-2, "max": 5.0, "scaling": "log"
+ "min": 2e-2,
+ "max": 5.0,
+ "scaling": "log"
},
"one_minus_beta1": {
- "feasible_points": [0.1]
+ "feasible_points": [0.1]
},
"warmup_factor": {
- "feasible_points": [0.05]
+ "feasible_points": [0.05]
},
"weight_decay": {
- "min": 1e-7, "max": 5e-5, "scaling": "log"
+ "min": 1e-7,
+ "max": 5e-5,
+ "scaling": "log"
},
"label_smoothing": {
- "feasible_points": [0.1, 0.2]
+ "feasible_points": [0.1, 0.2]
},
"dropout_rate": {
- "feasible_points": [0.0, 0.1]
+ "feasible_points": [0.0, 0.1]
},
"decay_steps_factor": {
"feasible_points": [0.9]
diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py
index ce6db3ac3..9ab193b56 100644
--- a/reference_algorithms/paper_baselines/sam/jax/submission.py
+++ b/reference_algorithms/paper_baselines/sam/jax/submission.py
@@ -3,11 +3,11 @@
import functools
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py
index fdd4eb8b7..652ebed1d 100644
--- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py
+++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py
@@ -2,12 +2,10 @@
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
-from absl import logging
import torch
import torch.distributed.nn as dist_nn
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from absl import logging
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
diff --git a/reference_algorithms/paper_baselines/sam/tuning_search_space.json b/reference_algorithms/paper_baselines/sam/tuning_search_space.json
index 66dae232b..f32058937 100644
--- a/reference_algorithms/paper_baselines/sam/tuning_search_space.json
+++ b/reference_algorithms/paper_baselines/sam/tuning_search_space.json
@@ -1,26 +1,32 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "min": 5e-2, "max": 0.43, "scaling": "log"
- },
- "beta2": {
- "feasible_points": [0.999]
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 1e-2, "max": 0.2, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- },
- "rho": {
- "feasible_points": [0.01, 0.02, 0.05]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "min": 5e-2,
+ "max": 0.43,
+ "scaling": "log"
+ },
+ "beta2": {
+ "feasible_points": [0.999]
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 1e-2,
+ "max": 0.2,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ },
+ "rho": {
+ "feasible_points": [0.01, 0.02, 0.05]
+ }
}
diff --git a/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json
index 89c480e7a..ee4e0c3e4 100644
--- a/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json
+++ b/reference_algorithms/paper_baselines/sam/tuning_search_space_no_beta1.json
@@ -1,26 +1,30 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "feasible_points": [0.1]
- },
- "beta2": {
- "feasible_points": [0.999]
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 1e-2, "max": 0.2, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- },
- "rho": {
- "feasible_points": [0.01, 0.02, 0.05]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "feasible_points": [0.1]
+ },
+ "beta2": {
+ "feasible_points": [0.999]
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 1e-2,
+ "max": 0.2,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ },
+ "rho": {
+ "feasible_points": [0.01, 0.02, 0.05]
+ }
}
diff --git a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py
index 830dd4816..c719361d3 100644
--- a/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py
+++ b/reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py
@@ -36,17 +36,17 @@
import functools
import itertools
import logging
-from typing import Any, cast, List, NamedTuple, Optional, TypeVar, Union
+from typing import Any, List, NamedTuple, Optional, TypeVar, Union, cast
import chex
-from flax import struct
import jax
-from jax import lax
-from jax.experimental import pjit
-from jax.experimental.sparse import linalg
import jax.numpy as jnp
import numpy as np
import optax
+from flax import struct
+from jax import lax
+from jax.experimental import pjit
+from jax.experimental.sparse import linalg
# Dtype for inverse-pth root routine
# Switch to f64 if you have hardware that supports it. Enable the jax flag
diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py
index 526afe7d5..8bf4d2dc5 100644
--- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py
+++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py
@@ -3,11 +3,11 @@
import functools
from typing import Any, Dict, Iterator, List, Optional, Tuple
-from flax import jax_utils
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from flax import jax_utils
+from jax import lax
from algoperf import spec
from reference_algorithms.paper_baselines.shampoo.jax.distributed_shampoo import (
diff --git a/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json b/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json
index 9d804ba0e..58f6f4fd1 100644
--- a/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json
+++ b/reference_algorithms/paper_baselines/shampoo/tuning_search_space.json
@@ -1,23 +1,29 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "min": 1e-2, "max": 0.15, "scaling": "log"
- },
- "beta2": {
- "feasible_points": [0.999]
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 5e-3, "max": 1.0, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "min": 1e-2,
+ "max": 0.15,
+ "scaling": "log"
+ },
+ "beta2": {
+ "feasible_points": [0.999]
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 5e-3,
+ "max": 1.0,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ }
}
diff --git a/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json b/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json
index b8bd2ea49..5a7c27be7 100644
--- a/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json
+++ b/reference_algorithms/paper_baselines/shampoo/tuning_search_space_no_beta1.json
@@ -1,23 +1,27 @@
{
- "learning_rate": {
- "min": 1e-4, "max": 1e-2, "scaling": "log"
- },
- "one_minus_beta1": {
- "feasible_points": [0.1]
- },
- "beta2": {
- "feasible_points": [0.999]
- },
- "warmup_factor": {
- "feasible_points": [0.05]
- },
- "weight_decay": {
- "min": 5e-3, "max": 1.0, "scaling": "log"
- },
- "label_smoothing": {
- "feasible_points": [0.1, 0.2]
- },
- "dropout_rate": {
- "feasible_points": [0.0, 0.1]
- }
+ "learning_rate": {
+ "min": 1e-4,
+ "max": 1e-2,
+ "scaling": "log"
+ },
+ "one_minus_beta1": {
+ "feasible_points": [0.1]
+ },
+ "beta2": {
+ "feasible_points": [0.999]
+ },
+ "warmup_factor": {
+ "feasible_points": [0.05]
+ },
+ "weight_decay": {
+ "min": 5e-3,
+ "max": 1.0,
+ "scaling": "log"
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1, 0.2]
+ },
+ "dropout_rate": {
+ "feasible_points": [0.0, 0.1]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/cosine_warmup.py b/reference_algorithms/target_setting_algorithms/cosine_warmup.py
index eeb87cd87..6a2241732 100644
--- a/reference_algorithms/target_setting_algorithms/cosine_warmup.py
+++ b/reference_algorithms/target_setting_algorithms/cosine_warmup.py
@@ -1,9 +1,7 @@
"""Implementions of a linear warmup then cosine decay LR schedule."""
import optax
-from torch.optim.lr_scheduler import CosineAnnealingLR
-from torch.optim.lr_scheduler import LinearLR
-from torch.optim.lr_scheduler import SequentialLR
+from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
def jax_cosine_warmup(step_hint: int, hyperparameters):
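(cosine_warmup.py implements the linear-warmup-then-cosine-decay schedule used by the target-setting baselines; the hunk above only consolidates its PyTorch scheduler imports. For orientation, here is a minimal JAX-side sketch of such a schedule, assuming optax's built-in warmup_cosine_decay_schedule; the parameter names are illustrative and not necessarily the ones jax_cosine_warmup itself takes.)

    import optax

    def make_warmup_cosine_schedule(step_hint, warmup_factor, learning_rate):
        # Linear warmup over the first warmup_factor * step_hint steps,
        # then cosine decay toward zero over the remaining steps.
        warmup_steps = int(warmup_factor * step_hint)
        return optax.warmup_cosine_decay_schedule(
            init_value=0.0,
            peak_value=learning_rate,
            warmup_steps=warmup_steps,
            decay_steps=step_hint,
            end_value=0.0,
        )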
diff --git a/reference_algorithms/target_setting_algorithms/criteo1tb_embed_init/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/criteo1tb_embed_init/tuning_search_space.json
index bd6c9702f..110138607 100644
--- a/reference_algorithms/target_setting_algorithms/criteo1tb_embed_init/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/criteo1tb_embed_init/tuning_search_space.json
@@ -1,28 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.002517072211464665
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9908351643533544
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9859568907533993
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 799
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.12274552870237089
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.002517072211464665]
+ },
+ "beta1": {
+ "feasible_points": [0.9908351643533544]
+ },
+ "beta2": {
+ "feasible_points": [0.9859568907533993]
+ },
+ "warmup_steps": {
+ "feasible_points": [799]
+ },
+ "weight_decay": {
+ "feasible_points": [0.12274552870237089]
}
-
\ No newline at end of file
+}
diff --git a/reference_algorithms/target_setting_algorithms/criteo1tb_layernorm/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/criteo1tb_layernorm/tuning_search_space.json
index 8d128dae1..a7f52681d 100644
--- a/reference_algorithms/target_setting_algorithms/criteo1tb_layernorm/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/criteo1tb_layernorm/tuning_search_space.json
@@ -1,28 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.05493199486120455
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.954922991734919
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9986188074995163
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 799
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.00011065469792077193
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.05493199486120455]
+ },
+ "beta1": {
+ "feasible_points": [0.954922991734919]
+ },
+ "beta2": {
+ "feasible_points": [0.9986188074995163]
+ },
+ "warmup_steps": {
+ "feasible_points": [799]
+ },
+ "weight_decay": {
+ "feasible_points": [0.00011065469792077193]
}
-
\ No newline at end of file
+}
diff --git a/reference_algorithms/target_setting_algorithms/criteo1tb_resnet/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/criteo1tb_resnet/tuning_search_space.json
index a33ae2ff5..31ce92bd1 100644
--- a/reference_algorithms/target_setting_algorithms/criteo1tb_resnet/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/criteo1tb_resnet/tuning_search_space.json
@@ -1,28 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.001493629901423942
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9592129978682067
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9824918272399145
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 399
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.00038587516415285595
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.001493629901423942]
+ },
+ "beta1": {
+ "feasible_points": [0.9592129978682067]
+ },
+ "beta2": {
+ "feasible_points": [0.9824918272399145]
+ },
+ "warmup_steps": {
+ "feasible_points": [399]
+ },
+ "weight_decay": {
+ "feasible_points": [0.00038587516415285595]
}
-
\ No newline at end of file
+}
diff --git a/reference_algorithms/target_setting_algorithms/fastmri/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/fastmri/tuning_search_space.json
index d8b4ed1b9..894ebb9fb 100644
--- a/reference_algorithms/target_setting_algorithms/fastmri/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/fastmri/tuning_search_space.json
@@ -1,37 +1,23 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.028609
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.981543
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9978504782314613
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 1357
- ]
- },
- "decay_steps_factor": {
- "feasible_points": [
- 0.984398
- ]
- },
- "end_factor": {
- "feasible_points": [
- 0.01
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.000576
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.028609]
+ },
+ "beta1": {
+ "feasible_points": [0.981543]
+ },
+ "beta2": {
+ "feasible_points": [0.9978504782314613]
+ },
+ "warmup_steps": {
+ "feasible_points": [1357]
+ },
+ "decay_steps_factor": {
+ "feasible_points": [0.984398]
+ },
+ "end_factor": {
+ "feasible_points": [0.01]
+ },
+ "weight_decay": {
+ "feasible_points": [0.000576]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/fastmri_layernorm/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/fastmri_layernorm/tuning_search_space.json
index 7a49ea891..a3aa8ea08 100644
--- a/reference_algorithms/target_setting_algorithms/fastmri_layernorm/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/fastmri_layernorm/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.008334676559764446
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.8294338711079317
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.8551723332825868
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 2714
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.01371235755699044
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.008334676559764446]
+ },
+ "beta1": {
+ "feasible_points": [0.8294338711079317]
+ },
+ "beta2": {
+ "feasible_points": [0.8551723332825868]
+ },
+ "warmup_steps": {
+ "feasible_points": [2714]
+ },
+ "weight_decay": {
+ "feasible_points": [0.01371235755699044]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/fastmri_model_size/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/fastmri_model_size/tuning_search_space.json
index 5516242df..21c5ac87c 100644
--- a/reference_algorithms/target_setting_algorithms/fastmri_model_size/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/fastmri_model_size/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.006173154695175443
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.8496694604806512
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.4639437428687345
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 1357
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.1679001017957879
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.006173154695175443]
+ },
+ "beta1": {
+ "feasible_points": [0.8496694604806512]
+ },
+ "beta2": {
+ "feasible_points": [0.4639437428687345]
+ },
+ "warmup_steps": {
+ "feasible_points": [1357]
+ },
+ "weight_decay": {
+ "feasible_points": [0.1679001017957879]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/fastmri_tanh/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/fastmri_tanh/tuning_search_space.json
index c3f06a686..59d624fe9 100644
--- a/reference_algorithms/target_setting_algorithms/fastmri_tanh/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/fastmri_tanh/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.04037951750205473
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9932215932637941
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9425306939334134
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 542
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.14877061239151607
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.04037951750205473]
+ },
+ "beta1": {
+ "feasible_points": [0.9932215932637941]
+ },
+ "beta2": {
+ "feasible_points": [0.9425306939334134]
+ },
+ "warmup_steps": {
+ "feasible_points": [542]
+ },
+ "weight_decay": {
+ "feasible_points": [0.14877061239151607]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json
index 649487c48..941bac70e 100644
--- a/reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/imagenet_resnet/tuning_search_space.json
@@ -1,42 +1,26 @@
{
- "learning_rate": {
- "feasible_points": [
- 4.131896390902391
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9274758113254791
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9978504782314613
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6999
- ]
- },
- "decay_steps_factor": {
- "feasible_points": [
- 0.9007765761611038
- ]
- },
- "end_factor": {
- "feasible_points": [
- 0.001
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 5.6687777311501786e-6
- ]
- },
- "label_smoothing": {
- "feasible_points": [
- 0.2
- ]
- }
+ "learning_rate": {
+ "feasible_points": [4.131896390902391]
+ },
+ "beta1": {
+ "feasible_points": [0.9274758113254791]
+ },
+ "beta2": {
+ "feasible_points": [0.9978504782314613]
+ },
+ "warmup_steps": {
+ "feasible_points": [6999]
+ },
+ "decay_steps_factor": {
+ "feasible_points": [0.9007765761611038]
+ },
+ "end_factor": {
+ "feasible_points": [0.001]
+ },
+ "weight_decay": {
+ "feasible_points": [5.6687777311501786e-6]
+ },
+ "label_smoothing": {
+ "feasible_points": [0.2]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/imagenet_resnet_gelu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_resnet_gelu/tuning_search_space.json
index 6524f5a5b..f72a4057d 100644
--- a/reference_algorithms/target_setting_algorithms/imagenet_resnet_gelu/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/imagenet_resnet_gelu/tuning_search_space.json
@@ -1,37 +1,23 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.3850582234619253
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9845129495436189
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6999
- ]
- },
- "decay_steps_factor": {
- "feasible_points": [
- 0.9504205232618159
- ]
- },
- "end_factor": {
- "feasible_points": [
- 0.001
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 1.7359160785435053e-5
- ]
- },
- "label_smoothing": {
- "feasible_points": [
- 0.2
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.3850582234619253]
+ },
+ "beta1": {
+ "feasible_points": [0.9845129495436189]
+ },
+ "warmup_steps": {
+ "feasible_points": [6999]
+ },
+ "decay_steps_factor": {
+ "feasible_points": [0.9504205232618159]
+ },
+ "end_factor": {
+ "feasible_points": [0.001]
+ },
+ "weight_decay": {
+ "feasible_points": [1.7359160785435053e-5]
+ },
+ "label_smoothing": {
+ "feasible_points": [0.2]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/imagenet_resnet_large_bn_init/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_resnet_large_bn_init/tuning_search_space.json
index 7ad32bb60..f58474cc8 100644
--- a/reference_algorithms/target_setting_algorithms/imagenet_resnet_large_bn_init/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/imagenet_resnet_large_bn_init/tuning_search_space.json
@@ -1,37 +1,23 @@
{
- "learning_rate": {
- "feasible_points": [
- 4.131896390902391
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9274758113254791
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6999
- ]
- },
- "decay_steps_factor": {
- "feasible_points": [
- 0.9007765761611038
- ]
- },
- "end_factor": {
- "feasible_points": [
- 0.01
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 5.6687777311501786e-6
- ]
- },
- "label_smoothing": {
- "feasible_points": [
- 0.2
- ]
- }
+ "learning_rate": {
+ "feasible_points": [4.131896390902391]
+ },
+ "beta1": {
+ "feasible_points": [0.9274758113254791]
+ },
+ "warmup_steps": {
+ "feasible_points": [6999]
+ },
+ "decay_steps_factor": {
+ "feasible_points": [0.9007765761611038]
+ },
+ "end_factor": {
+ "feasible_points": [0.01]
+ },
+ "weight_decay": {
+ "feasible_points": [5.6687777311501786e-6]
+ },
+ "label_smoothing": {
+ "feasible_points": [0.2]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/imagenet_resnet_silu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_resnet_silu/tuning_search_space.json
index 4556c6235..c63a87214 100644
--- a/reference_algorithms/target_setting_algorithms/imagenet_resnet_silu/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/imagenet_resnet_silu/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.01897755400372091
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9666072782043229
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.99681600289198
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.015653883841116094
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.01897755400372091]
+ },
+ "beta1": {
+ "feasible_points": [0.9666072782043229]
+ },
+ "beta2": {
+ "feasible_points": [0.99681600289198]
+ },
+ "warmup_steps": {
+ "feasible_points": [6999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.015653883841116094]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json
index 98360bdff..6c7501295 100644
--- a/reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/imagenet_vit/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0008445074561975979
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.8895758153482813
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9978504782314613
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.08135402759553023
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0008445074561975979]
+ },
+ "beta1": {
+ "feasible_points": [0.8895758153482813]
+ },
+ "beta2": {
+ "feasible_points": [0.9978504782314613]
+ },
+ "warmup_steps": {
+ "feasible_points": [6999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.08135402759553023]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/imagenet_vit_glu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_vit_glu/tuning_search_space.json
index 98360bdff..6c7501295 100644
--- a/reference_algorithms/target_setting_algorithms/imagenet_vit_glu/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/imagenet_vit_glu/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0008445074561975979
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.8895758153482813
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9978504782314613
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.08135402759553023
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0008445074561975979]
+ },
+ "beta1": {
+ "feasible_points": [0.8895758153482813]
+ },
+ "beta2": {
+ "feasible_points": [0.9978504782314613]
+ },
+ "warmup_steps": {
+ "feasible_points": [6999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.08135402759553023]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/imagenet_vit_map/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_vit_map/tuning_search_space.json
index 98360bdff..6c7501295 100644
--- a/reference_algorithms/target_setting_algorithms/imagenet_vit_map/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/imagenet_vit_map/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0008445074561975979
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.8895758153482813
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9978504782314613
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.08135402759553023
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0008445074561975979]
+ },
+ "beta1": {
+ "feasible_points": [0.8895758153482813]
+ },
+ "beta2": {
+ "feasible_points": [0.9978504782314613]
+ },
+ "warmup_steps": {
+ "feasible_points": [6999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.08135402759553023]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/imagenet_vit_post_ln/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/imagenet_vit_post_ln/tuning_search_space.json
index d6f2053ff..94711417c 100644
--- a/reference_algorithms/target_setting_algorithms/imagenet_vit_post_ln/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/imagenet_vit_post_ln/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.00026032497966327757
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9709035036599892
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.6572080806975734
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 13999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.03077045727617869
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.00026032497966327757]
+ },
+ "beta1": {
+ "feasible_points": [0.9709035036599892]
+ },
+ "beta2": {
+ "feasible_points": [0.6572080806975734]
+ },
+ "warmup_steps": {
+ "feasible_points": [13999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.03077045727617869]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/jax_adamw.py b/reference_algorithms/target_setting_algorithms/jax_adamw.py
index e6f8d915c..3fa5e4955 100644
--- a/reference_algorithms/target_setting_algorithms/jax_adamw.py
+++ b/reference_algorithms/target_setting_algorithms/jax_adamw.py
@@ -1,21 +1,21 @@
"""Submission file for an AdamW optimizer with warmup+cosine LR in Jax."""
-from flax import jax_utils
import jax
import jax.numpy as jnp
import optax
+from flax import jax_utils
from algoperf import spec
from reference_algorithms.target_setting_algorithms import cosine_warmup
-from reference_algorithms.target_setting_algorithms.data_selection import (
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
data_selection,
-) # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import (
+)
+from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401
get_batch_size,
-) # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_submission_base import (
+)
+from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401
update_params,
-) # pylint: disable=unused-import
+)
def init_optimizer_state(
diff --git a/reference_algorithms/target_setting_algorithms/jax_momentum.py b/reference_algorithms/target_setting_algorithms/jax_momentum.py
index b37464c1c..92d3f3a8f 100644
--- a/reference_algorithms/target_setting_algorithms/jax_momentum.py
+++ b/reference_algorithms/target_setting_algorithms/jax_momentum.py
@@ -2,21 +2,21 @@
from typing import Callable
-from flax import jax_utils
import jax
import jax.numpy as jnp
import optax
+from flax import jax_utils
from algoperf import spec
-from reference_algorithms.target_setting_algorithms.data_selection import (
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
data_selection,
-) # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import (
+)
+from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401
get_batch_size,
-) # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_submission_base import (
+)
+from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401
update_params,
-) # pylint: disable=unused-import
+)
def init_optimizer_state(
diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py
index 7d4a1fcc1..6883b17ab 100644
--- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py
+++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py
@@ -3,22 +3,22 @@
from typing import Any, Callable, NamedTuple, Optional, Union
import chex
-from flax import jax_utils
import jax
import jax.numpy as jnp
import optax
+from flax import jax_utils
from algoperf import spec
from reference_algorithms.target_setting_algorithms import cosine_warmup
-from reference_algorithms.target_setting_algorithms.data_selection import (
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
data_selection,
-) # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import (
+)
+from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401
get_batch_size,
-) # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_submission_base import (
+)
+from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401
update_params,
-) # pylint: disable=unused-import
+)
# Forked from
diff --git a/reference_algorithms/target_setting_algorithms/jax_nesterov.py b/reference_algorithms/target_setting_algorithms/jax_nesterov.py
index 714cbb225..7f4d1cd86 100644
--- a/reference_algorithms/target_setting_algorithms/jax_nesterov.py
+++ b/reference_algorithms/target_setting_algorithms/jax_nesterov.py
@@ -2,21 +2,21 @@
from typing import Callable
-from flax import jax_utils
import jax
import jax.numpy as jnp
import optax
+from flax import jax_utils
from algoperf import spec
-from reference_algorithms.target_setting_algorithms.data_selection import (
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
data_selection,
-) # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import (
+)
+from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401
get_batch_size,
-) # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_submission_base import (
+)
+from reference_algorithms.target_setting_algorithms.jax_submission_base import ( # noqa: F401
update_params,
-) # pylint: disable=unused-import
+)
def init_optimizer_state(
diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py
index 557e6957c..d9b12f5ca 100644
--- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py
+++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py
@@ -4,9 +4,9 @@
from typing import Any, Dict, List, Optional, Tuple
import jax
-from jax import lax
import jax.numpy as jnp
import optax
+from jax import lax
from algoperf import spec
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
index 482a28931..936833cf3 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.002106913873888147
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.8231189937738506
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.8774571227688758
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 1199
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.27590534177690645
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.002106913873888147]
+ },
+ "beta1": {
+ "feasible_points": [0.8231189937738506]
+ },
+ "beta2": {
+ "feasible_points": [0.8774571227688758]
+ },
+ "warmup_steps": {
+ "feasible_points": [1199]
+ },
+ "weight_decay": {
+ "feasible_points": [0.27590534177690645]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json
index 22f3376b4..faefa750e 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer_attention_temperature/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0007852999990476642
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.6994142393023162
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9918636824608852
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6000
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.07286322158086678
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0007852999990476642]
+ },
+ "beta1": {
+ "feasible_points": [0.6994142393023162]
+ },
+ "beta2": {
+ "feasible_points": [0.9918636824608852]
+ },
+ "warmup_steps": {
+ "feasible_points": [6000]
+ },
+ "weight_decay": {
+ "feasible_points": [0.07286322158086678]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json
index ad200c01b..16ab02525 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer_gelu/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.000590120167916659
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.737199286155609
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.05919391544031072
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 9999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.14128519778326312
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.000590120167916659]
+ },
+ "beta1": {
+ "feasible_points": [0.737199286155609]
+ },
+ "beta2": {
+ "feasible_points": [0.05919391544031072]
+ },
+ "warmup_steps": {
+ "feasible_points": [9999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.14128519778326312]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json
index 8297cf0ae..d596dcd2b 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer_layernorm/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0014446807792420305
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.7427148812902895
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.8993064520764248
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 3000
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.06875136511682291
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0014446807792420305]
+ },
+ "beta1": {
+ "feasible_points": [0.7427148812902895]
+ },
+ "beta2": {
+ "feasible_points": [0.8993064520764248]
+ },
+ "warmup_steps": {
+ "feasible_points": [3000]
+ },
+ "weight_decay": {
+ "feasible_points": [0.06875136511682291]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_no_resnet/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_no_resnet/tuning_search_space.json
index b31b711f7..dbcbecf78 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_no_resnet/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_no_resnet/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0035278622506232458
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.8192305396005781
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.495850879212151
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6000
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.04339748256184769
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0035278622506232458]
+ },
+ "beta1": {
+ "feasible_points": [0.8192305396005781]
+ },
+ "beta2": {
+ "feasible_points": [0.495850879212151]
+ },
+ "warmup_steps": {
+ "feasible_points": [6000]
+ },
+ "weight_decay": {
+ "feasible_points": [0.04339748256184769]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json
index e20a2dae1..bbea133cb 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.001308209823469072
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9731333693827139
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9981232922116359
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6000
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.16375311233774334
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.001308209823469072]
+ },
+ "beta1": {
+ "feasible_points": [0.9731333693827139]
+ },
+ "beta2": {
+ "feasible_points": [0.9981232922116359]
+ },
+ "warmup_steps": {
+ "feasible_points": [6000]
+ },
+ "weight_decay": {
+ "feasible_points": [0.16375311233774334]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json
index 0a9bfb3cf..52fe59d84 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.004958460849689891
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.863744242567442
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.6291854735396584
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 720
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.1147386261512052
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.004958460849689891]
+ },
+ "beta1": {
+ "feasible_points": [0.863744242567442]
+ },
+ "beta2": {
+ "feasible_points": [0.6291854735396584]
+ },
+ "warmup_steps": {
+ "feasible_points": [720]
+ },
+ "weight_decay": {
+ "feasible_points": [0.1147386261512052]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json
index e76a48325..898fc9e36 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_no_resnet/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0020162740358935045
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9604907112078142
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.8765457000160508
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 3600
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.0006149579248633481
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0020162740358935045]
+ },
+ "beta1": {
+ "feasible_points": [0.9604907112078142]
+ },
+ "beta2": {
+ "feasible_points": [0.8765457000160508]
+ },
+ "warmup_steps": {
+ "feasible_points": [3600]
+ },
+ "weight_decay": {
+ "feasible_points": [0.0006149579248633481]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json
index 55f70f9fc..94f150ad3 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_norm_and_spec_aug/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0014446807792420305
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.7427148812902895
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.8993064520764248
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 1800
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.06875136511682291
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0014446807792420305]
+ },
+ "beta1": {
+ "feasible_points": [0.7427148812902895]
+ },
+ "beta2": {
+ "feasible_points": [0.8993064520764248]
+ },
+ "warmup_steps": {
+ "feasible_points": [1800]
+ },
+ "weight_decay": {
+ "feasible_points": [0.06875136511682291]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json
index e5f906688..517e4a455 100644
--- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech_tanh/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.003604759885558324
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9931094324430452
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9976871843749077
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 720
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.120077307855989
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.003604759885558324]
+ },
+ "beta1": {
+ "feasible_points": [0.9931094324430452]
+ },
+ "beta2": {
+ "feasible_points": [0.9976871843749077]
+ },
+ "warmup_steps": {
+ "feasible_points": [720]
+ },
+ "weight_decay": {
+ "feasible_points": [0.120077307855989]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json
index 0f365a183..266f0f3f5 100644
--- a/reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/ogbg/tuning_search_space.json
@@ -1,32 +1,20 @@
{
- "learning_rate": {
- "feasible_points": [
- 2.4917728606918423
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9449369031171744
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 3000
- ]
- },
- "decay_steps_factor": {
- "feasible_points": [
- 0.861509027839639
- ]
- },
- "end_factor": {
- "feasible_points": [
- 0.001
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 1.2859640541025928e-7
- ]
- }
+ "learning_rate": {
+ "feasible_points": [2.4917728606918423]
+ },
+ "beta1": {
+ "feasible_points": [0.9449369031171744]
+ },
+ "warmup_steps": {
+ "feasible_points": [3000]
+ },
+ "decay_steps_factor": {
+ "feasible_points": [0.861509027839639]
+ },
+ "end_factor": {
+ "feasible_points": [0.001]
+ },
+ "weight_decay": {
+ "feasible_points": [1.2859640541025928e-7]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/ogbg_gelu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/ogbg_gelu/tuning_search_space.json
index 0749f96d6..b5b4aad30 100644
--- a/reference_algorithms/target_setting_algorithms/ogbg_gelu/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/ogbg_gelu/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.01897755400372091
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9666072782043229
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.99681600289198
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 3000
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.015653883841116094
- ]
- }
-}
\ No newline at end of file
+ "learning_rate": {
+ "feasible_points": [0.01897755400372091]
+ },
+ "beta1": {
+ "feasible_points": [0.9666072782043229]
+ },
+ "beta2": {
+ "feasible_points": [0.99681600289198]
+ },
+ "warmup_steps": {
+ "feasible_points": [3000]
+ },
+ "weight_decay": {
+ "feasible_points": [0.015653883841116094]
+ }
+}
diff --git a/reference_algorithms/target_setting_algorithms/ogbg_model_size/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/ogbg_model_size/tuning_search_space.json
index d5af3b03e..b69114b88 100644
--- a/reference_algorithms/target_setting_algorithms/ogbg_model_size/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/ogbg_model_size/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.001734480757979605
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.855609542347586
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9834185656478605
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 3000
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.019843063335529494
- ]
- }
-}
\ No newline at end of file
+ "learning_rate": {
+ "feasible_points": [0.001734480757979605]
+ },
+ "beta1": {
+ "feasible_points": [0.855609542347586]
+ },
+ "beta2": {
+ "feasible_points": [0.9834185656478605]
+ },
+ "warmup_steps": {
+ "feasible_points": [3000]
+ },
+ "weight_decay": {
+ "feasible_points": [0.019843063335529494]
+ }
+}
diff --git a/reference_algorithms/target_setting_algorithms/ogbg_silu/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/ogbg_silu/tuning_search_space.json
index b9f83a5ed..e1512c02a 100644
--- a/reference_algorithms/target_setting_algorithms/ogbg_silu/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/ogbg_silu/tuning_search_space.json
@@ -1,27 +1,17 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.00027866530268792414
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9919340993463499
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9979843253162892
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 6000
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.00032418357325210813
- ]
- }
-}
\ No newline at end of file
+ "learning_rate": {
+ "feasible_points": [0.00027866530268792414]
+ },
+ "beta1": {
+ "feasible_points": [0.9919340993463499]
+ },
+ "beta2": {
+ "feasible_points": [0.9979843253162892]
+ },
+ "warmup_steps": {
+ "feasible_points": [6000]
+ },
+ "weight_decay": {
+ "feasible_points": [0.00032418357325210813]
+ }
+}
diff --git a/reference_algorithms/target_setting_algorithms/pytorch_adamw.py b/reference_algorithms/target_setting_algorithms/pytorch_adamw.py
index f2474a706..14e8155d4 100644
--- a/reference_algorithms/target_setting_algorithms/pytorch_adamw.py
+++ b/reference_algorithms/target_setting_algorithms/pytorch_adamw.py
@@ -4,15 +4,15 @@
from algoperf import spec
from reference_algorithms.target_setting_algorithms import cosine_warmup
-from reference_algorithms.target_setting_algorithms.data_selection import (
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
data_selection,
-) # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import (
+)
+from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401
get_batch_size,
-) # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.pytorch_submission_base import (
+)
+from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401
update_params,
-) # pylint: disable=unused-import
+)
def init_optimizer_state(
diff --git a/reference_algorithms/target_setting_algorithms/pytorch_momentum.py b/reference_algorithms/target_setting_algorithms/pytorch_momentum.py
index 030939de5..a23b835eb 100644
--- a/reference_algorithms/target_setting_algorithms/pytorch_momentum.py
+++ b/reference_algorithms/target_setting_algorithms/pytorch_momentum.py
@@ -4,18 +4,18 @@
from torch.optim.lr_scheduler import LambdaLR
from algoperf import spec
-from reference_algorithms.target_setting_algorithms.data_selection import (
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
data_selection,
-) # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import (
+)
+from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401
get_batch_size,
-) # pylint: disable=unused-import
+)
from reference_algorithms.target_setting_algorithms.jax_momentum import (
create_lr_schedule_fn,
)
-from reference_algorithms.target_setting_algorithms.pytorch_submission_base import (
+from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401
update_params,
-) # pylint: disable=unused-import
+)
def init_optimizer_state(
diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py
index ceeebda6d..d301a233f 100644
--- a/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py
+++ b/reference_algorithms/target_setting_algorithms/pytorch_nadamw.py
@@ -8,15 +8,15 @@
from algoperf import spec
from reference_algorithms.target_setting_algorithms import cosine_warmup
-from reference_algorithms.target_setting_algorithms.data_selection import (
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
data_selection,
-) # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import (
+)
+from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401
get_batch_size,
-) # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.pytorch_submission_base import (
+)
+from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401
update_params,
-) # pylint: disable=unused-import
+)
# Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py
diff --git a/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py b/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py
index ddbcaefdb..3a6294a28 100644
--- a/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py
+++ b/reference_algorithms/target_setting_algorithms/pytorch_nesterov.py
@@ -4,18 +4,18 @@
from torch.optim.lr_scheduler import LambdaLR
from algoperf import spec
-from reference_algorithms.target_setting_algorithms.data_selection import (
+from reference_algorithms.target_setting_algorithms.data_selection import ( # noqa: F401
data_selection,
-) # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.get_batch_size import (
+)
+from reference_algorithms.target_setting_algorithms.get_batch_size import ( # noqa: F401
get_batch_size,
-) # pylint: disable=unused-import
-from reference_algorithms.target_setting_algorithms.jax_nesterov import (
+)
+from reference_algorithms.target_setting_algorithms.jax_momentum import (
create_lr_schedule_fn,
)
-from reference_algorithms.target_setting_algorithms.pytorch_submission_base import (
+from reference_algorithms.target_setting_algorithms.pytorch_submission_base import ( # noqa: F401
update_params,
-) # pylint: disable=unused-import
+)
def init_optimizer_state(
diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py
index 36f736a6b..0bef4548f 100644
--- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py
+++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py
@@ -2,9 +2,9 @@
from typing import Any, Dict, List, Optional, Tuple
-from absl import logging
import torch
import torch.distributed.nn as dist_nn
+from absl import logging
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
diff --git a/reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json
index f0ef45daa..aee18d976 100644
--- a/reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/wmt/tuning_search_space.json
@@ -1,32 +1,20 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0017486387539278373
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9326607383586145
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9955159689799007
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 1999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.08121616522670176
- ]
- },
- "label_smoothing": {
- "feasible_points": [
- 0.0
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0017486387539278373]
+ },
+ "beta1": {
+ "feasible_points": [0.9326607383586145]
+ },
+ "beta2": {
+ "feasible_points": [0.9955159689799007]
+ },
+ "warmup_steps": {
+ "feasible_points": [1999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.08121616522670176]
+ },
+ "label_smoothing": {
+ "feasible_points": [0.0]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/wmt_attention_temp/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/wmt_attention_temp/tuning_search_space.json
index 266cdedbb..e1ce2229f 100644
--- a/reference_algorithms/target_setting_algorithms/wmt_attention_temp/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/wmt_attention_temp/tuning_search_space.json
@@ -1,32 +1,20 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.000590120167916659
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.737199286155609
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.05919391544031072
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 9999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.14128519778326312
- ]
- },
- "label_smoothing": {
- "feasible_points": [
- 0.0
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.000590120167916659]
+ },
+ "beta1": {
+ "feasible_points": [0.737199286155609]
+ },
+ "beta2": {
+ "feasible_points": [0.05919391544031072]
+ },
+ "warmup_steps": {
+ "feasible_points": [9999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.14128519778326312]
+ },
+ "label_smoothing": {
+ "feasible_points": [0.0]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/wmt_glu_tanh/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/wmt_glu_tanh/tuning_search_space.json
index d288d9a49..0ed0f832a 100644
--- a/reference_algorithms/target_setting_algorithms/wmt_glu_tanh/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/wmt_glu_tanh/tuning_search_space.json
@@ -1,32 +1,20 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.000872041489644454
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.45562164405092065
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9982167124443476
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 4999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.01536114562763022
- ]
- },
- "label_smoothing": {
- "feasible_points": [
- 0.1
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.000872041489644454]
+ },
+ "beta1": {
+ "feasible_points": [0.45562164405092065]
+ },
+ "beta2": {
+ "feasible_points": [0.9982167124443476]
+ },
+ "warmup_steps": {
+ "feasible_points": [4999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.01536114562763022]
+ },
+ "label_smoothing": {
+ "feasible_points": [0.1]
+ }
}
diff --git a/reference_algorithms/target_setting_algorithms/wmt_post_ln/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/wmt_post_ln/tuning_search_space.json
index 1327bcb38..d1b3045c7 100644
--- a/reference_algorithms/target_setting_algorithms/wmt_post_ln/tuning_search_space.json
+++ b/reference_algorithms/target_setting_algorithms/wmt_post_ln/tuning_search_space.json
@@ -1,32 +1,20 @@
{
- "learning_rate": {
- "feasible_points": [
- 0.0003477912008450351
- ]
- },
- "beta1": {
- "feasible_points": [
- 0.9936632117510711
- ]
- },
- "beta2": {
- "feasible_points": [
- 0.9967873550453692
- ]
- },
- "warmup_steps": {
- "feasible_points": [
- 9999
- ]
- },
- "weight_decay": {
- "feasible_points": [
- 0.04120183162940475
- ]
- },
- "label_smoothing": {
- "feasible_points": [
- 0.0
- ]
- }
+ "learning_rate": {
+ "feasible_points": [0.0003477912008450351]
+ },
+ "beta1": {
+ "feasible_points": [0.9936632117510711]
+ },
+ "beta2": {
+ "feasible_points": [0.9967873550453692]
+ },
+ "warmup_steps": {
+ "feasible_points": [9999]
+ },
+ "weight_decay": {
+ "feasible_points": [0.04120183162940475]
+ },
+ "label_smoothing": {
+ "feasible_points": [0.0]
+ }
}
From 566b6d9ee6f0489fc1e27e567f01bf9d23689fc1 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 11:16:43 +0200
Subject: [PATCH 116/123] Lint algoperf/
---
algoperf/checkpoint_utils.py | 6 +++---
algoperf/data_utils.py | 6 ++----
algoperf/logger_utils.py | 12 ++++++------
algoperf/profiler.py | 4 ++--
algoperf/pytorch_utils.py | 5 ++---
algoperf/random_utils.py | 3 +--
.../workloads/cifar/cifar_jax/input_pipeline.py | 2 +-
algoperf/workloads/cifar/cifar_jax/models.py | 2 +-
algoperf/workloads/cifar/cifar_pytorch/models.py | 4 +---
.../workloads/cifar/cifar_pytorch/workload.py | 5 +----
algoperf/workloads/cifar/workload.py | 2 +-
algoperf/workloads/criteo1tb/workload.py | 2 +-
.../imagenet_jax/input_pipeline.py | 5 ++---
.../imagenet_resnet/imagenet_jax/models.py | 2 +-
.../imagenet_resnet/imagenet_jax/randaugment.py | 4 ----
.../imagenet_resnet/imagenet_pytorch/models.py | 3 +--
.../imagenet_pytorch/randaugment.py | 2 +-
algoperf/workloads/imagenet_resnet/imagenet_v2.py | 3 +--
.../librispeech_conformer/input_pipeline.py | 2 +-
.../librispeech_jax/librispeech_preprocessor.py | 4 ++--
.../librispeech_pytorch/preprocessor.py | 4 ++--
.../workloads/librispeech_conformer/metrics.py | 2 +-
.../librispeech_pytorch/models.py | 15 +++++++--------
.../workloads/mnist/mnist_pytorch/workload.py | 8 +++-----
algoperf/workloads/mnist/workload.py | 5 ++---
algoperf/workloads/ogbg/metrics.py | 4 ++--
algoperf/workloads/ogbg/workload.py | 3 +--
algoperf/workloads/wmt/bleu.py | 10 ++++------
algoperf/workloads/wmt/input_pipeline.py | 4 ++--
algoperf/workloads/wmt/tokenizer.py | 4 ++--
algoperf/workloads/wmt/wmt_jax/decode.py | 2 +-
31 files changed, 58 insertions(+), 81 deletions(-)
diff --git a/algoperf/checkpoint_utils.py b/algoperf/checkpoint_utils.py
index 577baaa34..f8cc40599 100644
--- a/algoperf/checkpoint_utils.py
+++ b/algoperf/checkpoint_utils.py
@@ -7,14 +7,14 @@
import os
from typing import Sequence, Tuple
+import jax
+import numpy as np
+import torch
from absl import logging
from flax import jax_utils
from flax.training import checkpoints as flax_checkpoints
from flax.training.checkpoints import latest_checkpoint
-import jax
-import numpy as np
from tensorflow.io import gfile # pytype: disable=import-error
-import torch
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
diff --git a/algoperf/data_utils.py b/algoperf/data_utils.py
index 919ccd125..f08d9d2db 100644
--- a/algoperf/data_utils.py
+++ b/algoperf/data_utils.py
@@ -7,9 +7,7 @@
import torch
import torch.distributed as dist
import torch.nn.functional as F
-from torch.utils.data import DataLoader
-from torch.utils.data import DistributedSampler
-from torch.utils.data import Sampler
+from torch.utils.data import DataLoader, DistributedSampler, Sampler
from algoperf import spec
@@ -266,7 +264,7 @@ def prefetched_loader(self) -> Iterable[Tuple[spec.Tensor, spec.Tensor]]:
next_targets = next_targets.to(self.device, non_blocking=True)
if not first:
- yield inputs, targets
+ yield inputs, targets # noqa: F821
else:
first = False
diff --git a/algoperf/logger_utils.py b/algoperf/logger_utils.py
index 3c4898142..17eea74a6 100644
--- a/algoperf/logger_utils.py
+++ b/algoperf/logger_utils.py
@@ -11,12 +11,12 @@
import sys
from typing import Any, Dict, Optional
-from absl import flags
-from clu import metric_writers
import GPUtil
import pandas as pd
import psutil
import torch.distributed as dist
+from absl import flags
+from clu import metric_writers
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
@@ -198,7 +198,7 @@ def _get_system_hardware_info() -> Dict:
try:
system_hardware_info['cpu_model_name'] = _get_cpu_model_name()
system_hardware_info['cpu_count'] = psutil.cpu_count()
- except: # pylint: disable=bare-except
+ except: # noqa: E722
logging.info('Unable to record cpu information. Continuing without it.')
gpus = GPUtil.getGPUs()
@@ -207,7 +207,7 @@ def _get_system_hardware_info() -> Dict:
system_hardware_info['gpu_model_name'] = gpus[0].name
system_hardware_info['gpu_count'] = len(gpus)
system_hardware_info['gpu_driver'] = gpus[0].driver
- except: # pylint: disable=bare-except
+ except: # noqa: E722
logging.info('Unable to record gpu information. Continuing without it.')
return system_hardware_info
@@ -232,7 +232,7 @@ def _get_system_software_info() -> Dict:
system_software_info['git_commit_hash'] = _get_git_commit_hash()
# Note: do not store git repo url as it may be sensitive or contain a
# secret.
- except: # pylint: disable=bare-except
+ except: # noqa: E722
logging.info('Unable to record git information. Continuing without it.')
return system_software_info
@@ -279,7 +279,7 @@ def _get_workload_properties(workload: spec.Workload) -> Dict:
for key in keys:
try:
attr = getattr(workload, key)
- except: # pylint: disable=bare-except
+ except: # noqa: E722
logging.info(
f'Unable to record workload.{key} information. Continuing without it.'
)
diff --git a/algoperf/profiler.py b/algoperf/profiler.py
index 0e791d3a8..534a5ccfb 100644
--- a/algoperf/profiler.py
+++ b/algoperf/profiler.py
@@ -4,10 +4,10 @@
https://github.com/Lightning-AI/lightning/tree/master/src/pytorch_lightning/profilers.
"""
-from collections import defaultdict
-from contextlib import contextmanager
import os
import time
+from collections import defaultdict
+from contextlib import contextmanager
from typing import Dict, Generator, List, Optional, Tuple
import numpy as np
diff --git a/algoperf/pytorch_utils.py b/algoperf/pytorch_utils.py
index 429e4d1e2..af09e67fc 100644
--- a/algoperf/pytorch_utils.py
+++ b/algoperf/pytorch_utils.py
@@ -1,14 +1,13 @@
import os
from typing import Tuple
-from absl import logging
import jax
import tensorflow as tf
import torch
-from torch import nn
-from torch import Tensor
import torch.distributed as dist
import torch.nn.functional as F
+from absl import logging
+from torch import Tensor, nn
from algoperf import spec
from algoperf.profiler import Profiler
diff --git a/algoperf/random_utils.py b/algoperf/random_utils.py
index 41f4b6b41..1dc773e80 100644
--- a/algoperf/random_utils.py
+++ b/algoperf/random_utils.py
@@ -2,9 +2,8 @@
from typing import Any, List, Union
-from absl import flags
-from absl import logging
import numpy as np
+from absl import flags, logging
try:
import jax.random as jax_rng
diff --git a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py
index 3d831c4af..7fbc95bc6 100644
--- a/algoperf/workloads/cifar/cifar_jax/input_pipeline.py
+++ b/algoperf/workloads/cifar/cifar_jax/input_pipeline.py
@@ -8,10 +8,10 @@
import functools
from typing import Dict, Iterator, Tuple
-from flax import jax_utils
import jax
import tensorflow as tf
import tensorflow_datasets as tfds
+from flax import jax_utils
from algoperf import spec
from algoperf.data_utils import shard_and_maybe_pad_np
diff --git a/algoperf/workloads/cifar/cifar_jax/models.py b/algoperf/workloads/cifar/cifar_jax/models.py
index 8d034796f..95238c997 100644
--- a/algoperf/workloads/cifar/cifar_jax/models.py
+++ b/algoperf/workloads/cifar/cifar_jax/models.py
@@ -7,8 +7,8 @@
import functools
from typing import Any, Callable, Tuple
-from flax import linen as nn
import jax.numpy as jnp
+from flax import linen as nn
from algoperf import spec
from algoperf.workloads.imagenet_resnet.imagenet_jax.models import ResNetBlock
diff --git a/algoperf/workloads/cifar/cifar_pytorch/models.py b/algoperf/workloads/cifar/cifar_pytorch/models.py
index 6beef89e6..0e08f5c5a 100644
--- a/algoperf/workloads/cifar/cifar_pytorch/models.py
+++ b/algoperf/workloads/cifar/cifar_pytorch/models.py
@@ -14,11 +14,9 @@
from algoperf.init_utils import pytorch_default_init
from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import (
BasicBlock,
-)
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import (
Bottleneck,
+ conv1x1,
)
-from algoperf.workloads.imagenet_resnet.imagenet_pytorch.models import conv1x1
class ResNet(nn.Module):
diff --git a/algoperf/workloads/cifar/cifar_pytorch/workload.py b/algoperf/workloads/cifar/cifar_pytorch/workload.py
index d7e858226..f1189bebc 100644
--- a/algoperf/workloads/cifar/cifar_pytorch/workload.py
+++ b/algoperf/workloads/cifar/cifar_pytorch/workload.py
@@ -12,10 +12,7 @@
from torchvision import transforms
from torchvision.datasets import CIFAR10
-from algoperf import data_utils
-from algoperf import param_utils
-from algoperf import pytorch_utils
-from algoperf import spec
+from algoperf import data_utils, param_utils, pytorch_utils, spec
from algoperf.workloads.cifar.cifar_pytorch.models import resnet18
from algoperf.workloads.cifar.workload import BaseCifarWorkload
diff --git a/algoperf/workloads/cifar/workload.py b/algoperf/workloads/cifar/workload.py
index 61880fbfa..31636807c 100644
--- a/algoperf/workloads/cifar/workload.py
+++ b/algoperf/workloads/cifar/workload.py
@@ -7,9 +7,9 @@
import jax
import torch
+import algoperf.random_utils as prng
from algoperf import spec
from algoperf.pytorch_utils import pytorch_setup
-import algoperf.random_utils as prng
USE_PYTORCH_DDP, _, _, _ = pytorch_setup()
diff --git a/algoperf/workloads/criteo1tb/workload.py b/algoperf/workloads/criteo1tb/workload.py
index 9fb819203..2cb7e5450 100644
--- a/algoperf/workloads/criteo1tb/workload.py
+++ b/algoperf/workloads/criteo1tb/workload.py
@@ -4,8 +4,8 @@
import os
from typing import Dict, Iterator, Optional, Tuple
-from absl import flags
import torch.distributed as dist
+from absl import flags
from algoperf import spec
from algoperf.workloads.criteo1tb import input_pipeline
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py
index fc42f4a5b..f782e50a1 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/input_pipeline.py
@@ -7,13 +7,12 @@
import functools
from typing import Dict, Iterator, Tuple
-from flax import jax_utils
import jax
import tensorflow as tf
import tensorflow_datasets as tfds
+from flax import jax_utils
-from algoperf import data_utils
-from algoperf import spec
+from algoperf import data_utils, spec
from algoperf.workloads.imagenet_resnet.imagenet_jax import randaugment
TFDS_SPLIT_NAME = {
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py
index 1f3911708..84ad4fe21 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/models.py
@@ -7,8 +7,8 @@
import functools
from typing import Any, Callable, Optional, Tuple
-from flax import linen as nn
import jax.numpy as jnp
+from flax import linen as nn
from algoperf import spec
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py
index 03b36e03d..87e218a0c 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_jax/randaugment.py
@@ -11,11 +11,7 @@
from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import (
rotate_img,
-)
-from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import (
transform,
-)
-from algoperf.workloads.imagenet_resnet.imagenet_jax.custom_tf_addons import (
translate,
)
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py
index ab3fc4a37..c980faa06 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/models.py
@@ -8,8 +8,7 @@
from typing import Any, Callable, List, Optional, Type, Union
import torch
-from torch import nn
-from torch import Tensor
+from torch import Tensor, nn
from algoperf import spec
from algoperf.init_utils import pytorch_default_init
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py
index 28ce00650..1c5c0d952 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_pytorch/randaugment.py
@@ -11,8 +11,8 @@
import PIL
import torch
from torch import Tensor
-from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode
+from torchvision.transforms import functional as F
from algoperf import spec
diff --git a/algoperf/workloads/imagenet_resnet/imagenet_v2.py b/algoperf/workloads/imagenet_resnet/imagenet_v2.py
index 6ffb73367..7a8e38f02 100644
--- a/algoperf/workloads/imagenet_resnet/imagenet_v2.py
+++ b/algoperf/workloads/imagenet_resnet/imagenet_v2.py
@@ -8,8 +8,7 @@
import tensorflow_datasets as tfds
-from algoperf import data_utils
-from algoperf import spec
+from algoperf import data_utils, spec
from algoperf.workloads.imagenet_resnet.imagenet_jax import input_pipeline
diff --git a/algoperf/workloads/librispeech_conformer/input_pipeline.py b/algoperf/workloads/librispeech_conformer/input_pipeline.py
index 570db07b3..23ce8e3b7 100644
--- a/algoperf/workloads/librispeech_conformer/input_pipeline.py
+++ b/algoperf/workloads/librispeech_conformer/input_pipeline.py
@@ -4,9 +4,9 @@
import csv
-from absl import logging
import numpy as np
import torch
+from absl import logging
class LibriSpeechDataset(torch.utils.data.Dataset):
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py b/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py
index bd36b1bb9..531e68a45 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_jax/librispeech_preprocessor.py
@@ -10,11 +10,11 @@
from typing import Any, Optional, Union
-from flax import linen as nn
-from flax import struct
import jax
import jax.numpy as jnp
import numpy as np
+from flax import linen as nn
+from flax import struct
# mel spectrum constants.
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
diff --git a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py
index f8c1bd0d2..58dd837dc 100644
--- a/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py
+++ b/algoperf/workloads/librispeech_conformer/librispeech_pytorch/preprocessor.py
@@ -2,14 +2,14 @@
https://github.com/google/init2winit/blob/master/init2winit/model_lib/librispeech_preprocessor.py.
"""
-from dataclasses import dataclass
import math
+from dataclasses import dataclass
from typing import Any, Optional, Union
import numpy as np
import torch
-from torch import nn
import torch.nn.functional as F
+from torch import nn
# mel spectrum constants.
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
diff --git a/algoperf/workloads/librispeech_conformer/metrics.py b/algoperf/workloads/librispeech_conformer/metrics.py
index d5c826575..7dd6a11dc 100644
--- a/algoperf/workloads/librispeech_conformer/metrics.py
+++ b/algoperf/workloads/librispeech_conformer/metrics.py
@@ -1,8 +1,8 @@
-from clu import metrics
import flax
import numpy as np
import tensorflow as tf
import tensorflow_text as tftxt
+from clu import metrics
gfile = tf.io.gfile
diff --git a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
index ddb7b5c37..aab75da63 100644
--- a/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
+++ b/algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/models.py
@@ -2,14 +2,14 @@
https://github.com/google/init2winit/blob/master/init2winit/model_lib/conformer.py.
"""
-from dataclasses import dataclass
import os
+from dataclasses import dataclass
from typing import Tuple
import torch
-from torch import nn
import torch.distributed.nn as dist_nn
import torch.nn.functional as F
+from torch import nn
from algoperf.workloads.librispeech_conformer.librispeech_pytorch import (
preprocessor,
@@ -88,7 +88,6 @@ def __init__(self, config: DeepspeechConfig):
self.lin = nn.LazyLinear(out_features=self.encoder_dim, bias=True)
def forward(self, inputs, input_paddings, dropout_rate):
-
output_paddings = input_paddings
outputs = inputs[:, None, :, :]
@@ -209,7 +208,6 @@ def __init__(self, config: DeepspeechConfig):
self.lin = nn.LazyLinear(out_features=config.encoder_dim, bias=True)
def forward(self, inputs, input_paddings, dropout_rate):
-
padding_mask = (1 - input_paddings)[:, :, None]
if self.config.layernorm_everywhere:
inputs = self.normalization_layer(inputs)
@@ -381,9 +379,9 @@ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE):
outputs, output_paddings = self.preprocessor(outputs, output_paddings)
if self.training and self.config.use_specaug:
outputs, output_paddings = self.specaug(outputs, output_paddings)
- outputs, output_paddings = self.subsample(outputs,
- output_paddings,
- dropout_rate)
+ outputs, output_paddings = self.subsample(
+ outputs, output_paddings, dropout_rate
+ )
for idx in range(self.config.num_lstm_layers):
if self.config.enable_residual_connections:
outputs = outputs + self.lstms[idx](outputs, output_paddings)
@@ -393,7 +391,8 @@ def forward(self, inputs, input_paddings, dropout_rate=DROPOUT_RATE):
for idx in range(self.config.num_ffn_layers):
if self.config.enable_residual_connections:
outputs = outputs + self.ffns[idx](
- outputs, output_paddings, dropout_rate)
+ outputs, output_paddings, dropout_rate
+ )
else:
outputs = self.ffns[idx](outputs, output_paddings, dropout_rate)
diff --git a/algoperf/workloads/mnist/mnist_pytorch/workload.py b/algoperf/workloads/mnist/mnist_pytorch/workload.py
index ca861d551..b58898703 100644
--- a/algoperf/workloads/mnist/mnist_pytorch/workload.py
+++ b/algoperf/workloads/mnist/mnist_pytorch/workload.py
@@ -1,18 +1,16 @@
"""MNIST workload implemented in PyTorch."""
-from collections import OrderedDict
import contextlib
+from collections import OrderedDict
from typing import Any, Dict, Iterator, Optional, Tuple
import torch
-from torch import nn
import torch.distributed as dist
import torch.nn.functional as F
+from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
-from algoperf import init_utils
-from algoperf import param_utils
-from algoperf import spec
+from algoperf import init_utils, param_utils, spec
from algoperf.pytorch_utils import pytorch_setup
from algoperf.workloads.mnist.workload import BaseMnistWorkload
diff --git a/algoperf/workloads/mnist/workload.py b/algoperf/workloads/mnist/workload.py
index 20aa975ae..38006b9ac 100644
--- a/algoperf/workloads/mnist/workload.py
+++ b/algoperf/workloads/mnist/workload.py
@@ -10,10 +10,9 @@
import tensorflow_datasets as tfds
import torch
-from algoperf import data_utils
-from algoperf import spec
-from algoperf.pytorch_utils import pytorch_setup
import algoperf.random_utils as prng
+from algoperf import data_utils, spec
+from algoperf.pytorch_utils import pytorch_setup
USE_PYTORCH_DDP, _, _, _ = pytorch_setup()
diff --git a/algoperf/workloads/ogbg/metrics.py b/algoperf/workloads/ogbg/metrics.py
index 668501788..3e7825219 100644
--- a/algoperf/workloads/ogbg/metrics.py
+++ b/algoperf/workloads/ogbg/metrics.py
@@ -2,14 +2,14 @@
# https://github.com/google/flax/blob/main/examples/ogbg_molpcba/train.py
from typing import Any
-from clu import metrics
import flax
import jax
import jax.numpy as jnp
import numpy as np
-from sklearn.metrics import average_precision_score
import torch
import torch.distributed as dist
+from clu import metrics
+from sklearn.metrics import average_precision_score
from algoperf.pytorch_utils import pytorch_setup
diff --git a/algoperf/workloads/ogbg/workload.py b/algoperf/workloads/ogbg/workload.py
index 1d182fed5..8717e46d6 100644
--- a/algoperf/workloads/ogbg/workload.py
+++ b/algoperf/workloads/ogbg/workload.py
@@ -9,8 +9,7 @@
from algoperf import random_utils as prng
from algoperf import spec
-from algoperf.workloads.ogbg import input_pipeline
-from algoperf.workloads.ogbg import metrics
+from algoperf.workloads.ogbg import input_pipeline, metrics
class BaseOgbgWorkload(spec.Workload):
diff --git a/algoperf/workloads/wmt/bleu.py b/algoperf/workloads/wmt/bleu.py
index e0064ba51..6e29b1b83 100644
--- a/algoperf/workloads/wmt/bleu.py
+++ b/algoperf/workloads/wmt/bleu.py
@@ -5,19 +5,17 @@
https://github.com/mjpost/sacrebleu/blob/v1.3.1/sacrebleu.py.
"""
-from collections import Counter
-from collections import namedtuple
-from itertools import zip_longest
-import logging
import math
import re
import sys
-from typing import List, Sequence
import unicodedata
+from collections import Counter, namedtuple
+from itertools import zip_longest
+from typing import List, Sequence
-from absl import logging
import torch
import torch.distributed as dist
+from absl import logging
from algoperf.pytorch_utils import pytorch_setup
diff --git a/algoperf/workloads/wmt/input_pipeline.py b/algoperf/workloads/wmt/input_pipeline.py
index 1df1dfc55..3d184cd78 100644
--- a/algoperf/workloads/wmt/input_pipeline.py
+++ b/algoperf/workloads/wmt/input_pipeline.py
@@ -238,8 +238,8 @@ def preprocess_wmt_data(
def length_filter(max_len):
def filter_fn(x):
source, target = x['inputs'], x['targets']
- l = tf.maximum(tf.shape(source)[0], tf.shape(target)[0])
- return tf.less(l, max_len + 1)
+ length = tf.maximum(tf.shape(source)[0], tf.shape(target)[0])
+ return tf.less(length, max_len + 1)
return filter_fn
diff --git a/algoperf/workloads/wmt/tokenizer.py b/algoperf/workloads/wmt/tokenizer.py
index d94cab808..273e11dfa 100644
--- a/algoperf/workloads/wmt/tokenizer.py
+++ b/algoperf/workloads/wmt/tokenizer.py
@@ -9,11 +9,11 @@
import time
from typing import Any, Dict, Iterable, Tuple
-from absl import logging
import jax
-from sentencepiece import SentencePieceTrainer
import tensorflow as tf
import tensorflow_text as tftxt
+from absl import logging
+from sentencepiece import SentencePieceTrainer
Features = Dict[str, tf.Tensor]
diff --git a/algoperf/workloads/wmt/wmt_jax/decode.py b/algoperf/workloads/wmt/wmt_jax/decode.py
index b5f5f1099..196d9175e 100644
--- a/algoperf/workloads/wmt/wmt_jax/decode.py
+++ b/algoperf/workloads/wmt/wmt_jax/decode.py
@@ -7,9 +7,9 @@
import flax
import jax
-from jax import lax
import jax.numpy as jnp
import numpy as np
+from jax import lax
# Constants
# We assume the default End-of-Sentence token id is 2 (SentencePiece).
From e8466489a87b631f2f73f79a4e50c0d8ffef7544 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 12:41:28 +0200
Subject: [PATCH 117/123] Remove unnecessary isort=off commands
---
.../external_tuning/jax_nadamw_full_budget.py | 4 ----
.../external_tuning/jax_nadamw_target_setting.py | 4 ----
.../self_tuning/jax_nadamw_full_budget.py | 4 ----
.../self_tuning/jax_nadamw_target_setting.py | 4 ----
reference_algorithms/paper_baselines/nadamw/jax/submission.py | 4 ----
5 files changed, 20 deletions(-)
diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
index ed721d167..62161b3d5 100644
--- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
+++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py
@@ -1,9 +1,6 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax."""
import functools
-
-# isort: off
-# We have to turn off isort here to resolve a conflict between isort and yapf.
from typing import (
Any,
Callable,
@@ -15,7 +12,6 @@
Tuple,
Union,
)
-# isort: on
import chex
import jax
diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py
index fdcdd5348..9752aef33 100644
--- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py
+++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py
@@ -1,9 +1,6 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax."""
import functools
-
-# isort: off
-# We have to turn off isort here to resolve a conflict between isort and yapf.
from typing import (
Any,
Callable,
@@ -15,7 +12,6 @@
Tuple,
Union,
)
-# isort: on
import chex
import jax
diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py
index 0b4e5aba3..f61a7bdcd 100644
--- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py
+++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py
@@ -1,9 +1,6 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax."""
import functools
-
-# isort: off
-# We have to turn off isort here to resolve a conflict between isort and yapf.
from typing import (
Any,
Callable,
@@ -15,7 +12,6 @@
Tuple,
Union,
)
-# isort: on
import chex
import jax
diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py
index d3efc3a55..130ebdabe 100644
--- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py
+++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py
@@ -1,9 +1,6 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax."""
import functools
-
-# isort: off
-# We have to turn off isort here to resolve a conflict between isort and yapf.
from typing import (
Any,
Callable,
@@ -15,7 +12,6 @@
Tuple,
Union,
)
-# isort: on
import chex
import jax
diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py
index ed721d167..62161b3d5 100644
--- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py
+++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py
@@ -1,9 +1,6 @@
"""Submission file for an NAdamW optimizer with warmup+cosine LR in Jax."""
import functools
-
-# isort: off
-# We have to turn off isort here to resolve a conflict between isort and yapf.
from typing import (
Any,
Callable,
@@ -15,7 +12,6 @@
Tuple,
Union,
)
-# isort: on
import chex
import jax
From 3e425e07fa89eb25b7be839fa649db03915ea462 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 13:06:29 +0200
Subject: [PATCH 118/123] Update Ruff linting rules in pyproject.toml to
include additional options for future use
---
pyproject.toml | 32 ++++++++++++++++++++++++++++++--
1 file changed, 30 insertions(+), 2 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 2db40f61f..976857e8c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -121,8 +121,36 @@ target-version = "py311"
quote-style = "single"
[tool.ruff.lint]
-extend-select = ["I"]
-# Could add (in the future): "E", "F", "UP", "B", "SIM", "PL"
+# Could add the commented out rules in the future:
+extend-select = [
+ "BLE", # disallow catch-all exceptions
+ "COM", # enforce trailing comma rules
+ "F", # Pyflakes rules
+ "FA", # Enforce from __future__ import annotations
+ "I", # Isort rules
+ "ICN", # Use common import conventions
+ "TID", # Some good import practices
+ # "A", # flake8-builtins: detect shadowed builtins
+ # "B", # flake8-bugbear:
+ # "C4", # flake8-comprehensions: catch incorrect use of comprehensions
+ # "DOC", # pydoclint
+ # "D", # pydocstyle
+ # "DTZ", # flake8-datetimez: strict timezone manipulation with datetime
+ # "E", # pycodestyle errors
+ # "FBT", # flake8-boolean-trap: detect boolean traps
+ # "ISC", # flake8-implicit-str-concat: good use of string concatenation
+ # "N", # pep8-naming: enforce naming conventions
+ # "NPY", # Some numpy-specific things
+ # "PL", # Pylint rules
+ # "PTH", # flake8-use-pathlib: use pathlib instead of os.path
+ # "RET", # flake8-return: good return practices
+ # "S", # flake8-bandit: security testing
+ # "SIM", # flake8-simplify: common simplification rules
+ # "TC", # flake8-type-checking: enforce importing certain types in a TYPE_CHECKING block
+ # "TD", # flake8-todo: Be diligent with TODO comments
+  # "UP", # pyupgrade: Warn if things can be changed due to newer versions
+ # "W", # pycodestyle warnings
+]
ignore = [
# Conflicting lint rules with Ruff's formatter
# (see https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules).
From 8bca401e32c739c88fe0028a282e76a340dc4b57 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Mon, 23 Jun 2025 13:09:29 +0200
Subject: [PATCH 119/123] Add pylint errors to linting rules
---
pyproject.toml | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 976857e8c..b6cfa42cd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -129,19 +129,23 @@ extend-select = [
"FA", # Enforce from __future__ import annotations
"I", # Isort rules
"ICN", # Use common import conventions
+ "PLE", # Pylint Errors
"TID", # Some good import practices
# "A", # flake8-builtins: detect shadowed builtins
# "B", # flake8-bugbear:
# "C4", # flake8-comprehensions: catch incorrect use of comprehensions
- # "DOC", # pydoclint
# "D", # pydocstyle
+ # "DOC", # pydoclint
# "DTZ", # flake8-datetimez: strict timezone manipulation with datetime
# "E", # pycodestyle errors
# "FBT", # flake8-boolean-trap: detect boolean traps
# "ISC", # flake8-implicit-str-concat: good use of string concatenation
# "N", # pep8-naming: enforce naming conventions
# "NPY", # Some numpy-specific things
- # "PL", # Pylint rules
+ # "PL", # All Pylint rules
+ # "PLC", # Pylint Convention
+ # "PLR", # Pylint Refactor
+ # "PLW", # Pylint Warnings
# "PTH", # flake8-use-pathlib: use pathlib instead of os.path
# "RET", # flake8-return: good return practices
# "S", # flake8-bandit: security testing
From 09aca7f501fe5cce5b76596ffaa04e34dfdbc726 Mon Sep 17 00:00:00 2001
From: Frank Schneider
Date: Wed, 25 Jun 2025 11:47:53 +0200
Subject: [PATCH 120/123] Fix formatting
---
algoperf/jax_utils.py | 149 ++++++++++++++++----------------
tests/test_jax_utils.py | 182 ++++++++++++++++++----------------------
2 files changed, 157 insertions(+), 174 deletions(-)
diff --git a/algoperf/jax_utils.py b/algoperf/jax_utils.py
index 28a4ba8c9..dab338328 100644
--- a/algoperf/jax_utils.py
+++ b/algoperf/jax_utils.py
@@ -1,95 +1,92 @@
from collections.abc import Sequence
import flax.linen as nn
-from flax.linen.module import compact
-from flax.linen.module import merge_param
-from flax.linen.module import Module
-from flax.typing import PRNGKey
import jax
-from jax import lax
-from jax import random
import jax.numpy as jnp
+from flax.linen.module import Module, compact, merge_param
+from flax.typing import PRNGKey
+from jax import lax, random
# Custom Layers
class Dropout(Module):
# pylint: disable=line-too-long
"""Create a dropout layer.
- Forked from
- https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout.
- The reference dropout implementation is modified support changes
- to dropout rate during training by:
- 1) adding rate argument to the __call__ method.
- 2) removing the if-else condition to check for edge cases, which
- will trigger a recompile for jitted code.
-
- .. note::
- When using :meth:`Module.apply() `, make sure
- to include an RNG seed named ``'dropout'``. Dropout isn't necessary for
- variable initialization.
-
- Example usage::
-
- >>> import flax.linen as nn
- >>> import jax, jax.numpy as jnp
-
- >>> class MLP(nn.Module):
- ... @nn.compact
- ... def __call__(self, x, train):
- ... x = nn.Dense(4)(x)
- ... x = nn.Dropout(0.5, deterministic=not train)(x)
- ... return x
-
- >>> model = MLP()
- >>> x = jnp.ones((1, 3))
- >>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout
- >>> model.apply(variables, x, train=False) # don't use dropout
- Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32)
- >>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout
- Array([[-0.35751054, 3.2510893 , 0. , 0. ]], dtype=float32)
-
- Attributes:
- rate: the dropout probability. (_not_ the keep rate!)
- broadcast_dims: dimensions that will share the same dropout mask
- deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
- and masked, whereas if true, no mask is applied and the inputs are
- returned as is.
- rng_collection: the rng collection name to use when requesting an rng
- key.
- """
+ Forked from
+ https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout.
+  The reference dropout implementation is modified to support changes
+ to dropout rate during training by:
+ 1) adding rate argument to the __call__ method.
+ 2) removing the if-else condition to check for edge cases, which
+ will trigger a recompile for jitted code.
+
+ .. note::
+ When using :meth:`Module.apply() `, make sure
+ to include an RNG seed named ``'dropout'``. Dropout isn't necessary for
+ variable initialization.
+
+ Example usage::
+
+ >>> import flax.linen as nn
+ >>> import jax, jax.numpy as jnp
+
+ >>> class MLP(nn.Module):
+ ... @nn.compact
+ ... def __call__(self, x, train):
+ ... x = nn.Dense(4)(x)
+ ... x = nn.Dropout(0.5, deterministic=not train)(x)
+ ... return x
+
+ >>> model = MLP()
+ >>> x = jnp.ones((1, 3))
+ >>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout
+ >>> model.apply(variables, x, train=False) # don't use dropout
+ Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32)
+ >>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout
+ Array([[-0.35751054, 3.2510893 , 0. , 0. ]], dtype=float32)
+
+ Attributes:
+ rate: the dropout probability. (_not_ the keep rate!)
+ broadcast_dims: dimensions that will share the same dropout mask
+ deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
+ and masked, whereas if true, no mask is applied and the inputs are
+ returned as is.
+ rng_collection: the rng collection name to use when requesting an rng
+ key.
+ """
rate: float | None = None
broadcast_dims: Sequence[int] = ()
deterministic: bool | None = None
- rng_collection: str = "dropout"
+ rng_collection: str = 'dropout'
legacy: bool = False
@compact
def __call__(
- self,
- inputs,
- deterministic: bool | None = None,
- rate: float | None = None,
- rng: PRNGKey | None = None,
+ self,
+ inputs,
+ deterministic: bool | None = None,
+ rate: float | None = None,
+ rng: PRNGKey | None = None,
):
"""Applies a random dropout mask to the input.
- Args:
- inputs: the inputs that should be randomly masked.
- deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
- and masked, whereas if true, no mask is applied and the inputs are
- returned as is.
- rate: the dropout probability. (_not_ the keep rate!)
- rng: an optional PRNGKey used as the random key, if not specified,
- one will be generated using ``make_rng`` with the
- ``rng_collection`` name.
-
- Returns:
- The masked inputs reweighted to preserve mean.
- """
- deterministic = merge_param("deterministic",
- self.deterministic,
- deterministic)
+ Args:
+ inputs: the inputs that should be randomly masked.
+ deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
+ and masked, whereas if true, no mask is applied and the inputs are
+ returned as is.
+ rate: the dropout probability. (_not_ the keep rate!)
+ rng: an optional PRNGKey used as the random key, if not specified,
+ one will be generated using ``make_rng`` with the
+ ``rng_collection`` name.
+
+ Returns:
+ The masked inputs reweighted to preserve mean.
+ """
+ deterministic = merge_param(
+ 'deterministic', self.deterministic, deterministic
+ )
# Override self.rate if rate is passed to __call__
if rate is None:
@@ -121,10 +118,12 @@ def __call__(
def print_jax_model_summary(model, fake_inputs):
"""Prints a summary of the jax module."""
tabulate_fn = nn.tabulate(
- model,
- jax.random.PRNGKey(0),
- console_kwargs={
- "force_terminal": False, "force_jupyter": False, "width": 240
- },
+ model,
+ jax.random.PRNGKey(0),
+ console_kwargs={
+ 'force_terminal': False,
+ 'force_jupyter': False,
+ 'width': 240,
+ },
)
print(tabulate_fn(fake_inputs, train=False))
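
The reformatting above does not change the behaviour the docstring describes: the dropout rate can be supplied at call time, so a jitted training step is not retraced when the rate changes. A minimal usage sketch, assuming the `algoperf.jax_utils.Dropout` from this patch (the `TinyModel` module and the concrete rates are made up for illustration; the pattern mirrors the jitted test further down):

```python
from functools import partial

import flax.linen as nn
import jax
import jax.numpy as jnp

from algoperf.jax_utils import Dropout


class TinyModel(nn.Module):  # hypothetical module, for illustration only
  @nn.compact
  def __call__(self, x, train, dropout_rate=0.5):
    x = nn.Dense(4)(x)
    # Pass the rate at call time; it is traced rather than baked into the jaxpr.
    return Dropout(rate=dropout_rate, deterministic=not train)(x, rate=dropout_rate)


model = TinyModel()
x = jnp.ones((1, 3))
variables = model.init({'params': jax.random.PRNGKey(0)}, x, train=False)

# Only `train` is static; changing dropout_rate reuses the compiled function.
jitted_apply = jax.jit(partial(model.apply), static_argnames=['train'])
for rate in (0.0, 0.1, 0.2):
  y = jitted_apply(
    variables,
    x,
    train=True,
    dropout_rate=rate,
    rngs={'dropout': jax.random.PRNGKey(1)},
  )
```
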
diff --git a/tests/test_jax_utils.py b/tests/test_jax_utils.py
index d54bf47aa..28e506400 100644
--- a/tests/test_jax_utils.py
+++ b/tests/test_jax_utils.py
@@ -5,14 +5,11 @@
from functools import partial
-from absl.testing import absltest
-from absl.testing import parameterized
import flax.linen as nn
import jax
import jax.numpy as jnp
-from jax.tree_util import tree_leaves
-from jax.tree_util import tree_map
-from jax.tree_util import tree_structure
+from absl.testing import absltest, parameterized
+from jax.tree_util import tree_leaves, tree_map, tree_structure
from algoperf.jax_utils import Dropout
@@ -22,9 +19,9 @@
def pytrees_are_equal(a, b, rtol=1e-5, atol=1e-8):
"""
- A custom function to check if two PyTrees are equal, handling floats with
- a tolerance.
- """
+ A custom function to check if two PyTrees are equal, handling floats with
+ a tolerance.
+ """
if tree_structure(a) != tree_structure(b):
return False
@@ -51,28 +48,22 @@ def __call__(self, x, train):
class DropoutModel(nn.Module):
-
@nn.compact
def __call__(self, x, train, dropout_rate=DEFAULT_DROPOUT):
- return Dropout(
- rate=dropout_rate, deterministic=not train)(
- x, rate=dropout_rate)
+ return Dropout(rate=dropout_rate, deterministic=not train)(
+ x, rate=dropout_rate
+ )
class DropoutTest(parameterized.TestCase):
-
@parameterized.named_parameters(
- dict(
- testcase_name="Dropout, p=0.0, train", dropout_rate=0.0,
- mode="train"),
- dict(testcase_name="Dropout, p=0.0, eval", dropout_rate=0.0, mode="eval"),
- dict(
- testcase_name="Dropout, p=0.1, train", dropout_rate=0.1,
- mode="train"),
- dict(testcase_name="Dropout, p=0.1, eval", dropout_rate=0.1, mode="eval"),
+ dict(testcase_name='Dropout, p=0.0, train', dropout_rate=0.0, mode='train'),
+ dict(testcase_name='Dropout, p=0.0, eval', dropout_rate=0.0, mode='eval'),
+ dict(testcase_name='Dropout, p=0.1, train', dropout_rate=0.1, mode='train'),
+ dict(testcase_name='Dropout, p=0.1, eval', dropout_rate=0.1, mode='eval'),
)
def test_forward(self, dropout_rate, mode):
- """ Compare forward pass of Dropout layer to flax.linen.Dropout in train and
+ """Compare forward pass of Dropout layer to flax.linen.Dropout in train and
eval mode.
"""
@@ -82,44 +73,39 @@ def test_forward(self, dropout_rate, mode):
orig_model = LegacyDropoutModel(dropout_rate=dropout_rate)
cust_model = DropoutModel()
- initial_variables_original = orig_model.init({"params": rng},
- fake_batch,
- train=False)
- initial_variables_custom = cust_model.init({"params": rng},
- fake_batch,
- train=False)
+ initial_variables_original = orig_model.init(
+ {'params': rng}, fake_batch, train=False
+ )
+ initial_variables_custom = cust_model.init(
+ {'params': rng}, fake_batch, train=False
+ )
assert pytrees_are_equal(
- initial_variables_original, initial_variables_custom, rtol=1e-6)
+ initial_variables_original, initial_variables_custom, rtol=1e-6
+ )
# forward pass
x = jnp.ones((10,))
- train = mode == "train"
+ train = mode == 'train'
y1 = orig_model.apply(
- initial_variables_original,
- x,
- train=train,
- rngs={"dropout": dropout_rng})
+ initial_variables_original, x, train=train, rngs={'dropout': dropout_rng}
+ )
y2 = cust_model.apply(
- initial_variables_custom,
- x,
- train=train,
- dropout_rate=dropout_rate,
- rngs={"dropout": dropout_rng},
+ initial_variables_custom,
+ x,
+ train=train,
+ dropout_rate=dropout_rate,
+ rngs={'dropout': dropout_rng},
)
assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
@parameterized.named_parameters(
- dict(
- testcase_name="Dropout, p=0.0, train", dropout_rate=0.0,
- mode="train"),
- dict(testcase_name="Dropout, p=0.0, eval", dropout_rate=0.0, mode="eval"),
- dict(
- testcase_name="Dropout, p=0.1, train", dropout_rate=0.1,
- mode="train"),
- dict(testcase_name="Dropout, p=0.1, eval", dropout_rate=0.1, mode="eval"),
+ dict(testcase_name='Dropout, p=0.0, train', dropout_rate=0.0, mode='train'),
+ dict(testcase_name='Dropout, p=0.0, eval', dropout_rate=0.0, mode='eval'),
+ dict(testcase_name='Dropout, p=0.1, train', dropout_rate=0.1, mode='train'),
+ dict(testcase_name='Dropout, p=0.1, eval', dropout_rate=0.1, mode='eval'),
)
def test_dropout_update(self, dropout_rate, mode):
"""Call forward pass of Dropout layer with two different dropout rates
@@ -132,57 +118,51 @@ def test_dropout_update(self, dropout_rate, mode):
orig_model = LegacyDropoutModel(dropout_rate=dropout_rate)
cust_model = DropoutModel()
- initial_variables_original = orig_model.init({"params": rng},
- fake_batch,
- train=False)
+ initial_variables_original = orig_model.init(
+ {'params': rng}, fake_batch, train=False
+ )
- initial_variables_custom = cust_model.init({"params": rng},
- fake_batch,
- train=False)
+ initial_variables_custom = cust_model.init(
+ {'params': rng}, fake_batch, train=False
+ )
assert pytrees_are_equal(
- initial_variables_original, initial_variables_custom, rtol=1e-6)
+ initial_variables_original, initial_variables_custom, rtol=1e-6
+ )
# forward pass
x = jnp.ones((10,))
- train = mode == "train"
+ train = mode == 'train'
y1 = orig_model.apply(
- initial_variables_original,
- x,
- train=train,
- rngs={"dropout": dropout_rng})
+ initial_variables_original, x, train=train, rngs={'dropout': dropout_rng}
+ )
_ = cust_model.apply(
- initial_variables_custom,
- x,
- train=train,
- dropout_rate=0.9,
- rngs={"dropout": dropout_rng},
+ initial_variables_custom,
+ x,
+ train=train,
+ dropout_rate=0.9,
+ rngs={'dropout': dropout_rng},
)
y2 = cust_model.apply(
- initial_variables_custom,
- x,
- train=train,
- dropout_rate=dropout_rate,
- rngs={"dropout": dropout_rng},
+ initial_variables_custom,
+ x,
+ train=train,
+ dropout_rate=dropout_rate,
+ rngs={'dropout': dropout_rng},
)
assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
@parameterized.named_parameters(
- dict(
- testcase_name="Dropout, p=0.0, train", dropout_rate=0.0,
- mode="train"),
- dict(testcase_name="Dropout, p=0.0, eval", dropout_rate=0.0, mode="eval"),
- dict(
- testcase_name="Dropout, p=0.1, train", dropout_rate=0.1,
- mode="train"),
- dict(testcase_name="Dropout, p=0.1, eval", dropout_rate=0.1, mode="eval"),
+ dict(testcase_name='Dropout, p=0.0, train', dropout_rate=0.0, mode='train'),
+ dict(testcase_name='Dropout, p=0.0, eval', dropout_rate=0.0, mode='eval'),
+ dict(testcase_name='Dropout, p=0.1, train', dropout_rate=0.1, mode='train'),
+ dict(testcase_name='Dropout, p=0.1, eval', dropout_rate=0.1, mode='eval'),
)
def test_jitted_updates(self, dropout_rate, mode):
- """ Compare jitted updates with dropout.
- """
+ """Compare jitted updates with dropout."""
# initialize models
rng, dropout_rng = jax.random.split(jax.random.key(SEED), 2)
@@ -190,42 +170,46 @@ def test_jitted_updates(self, dropout_rate, mode):
orig_model = LegacyDropoutModel(dropout_rate=dropout_rate)
cust_model = DropoutModel()
- initial_variables_original = orig_model.init({"params": rng},
- fake_batch,
- train=False)
- initial_variables_custom = cust_model.init({"params": rng},
- fake_batch,
- train=False)
+ initial_variables_original = orig_model.init(
+ {'params': rng}, fake_batch, train=False
+ )
+ initial_variables_custom = cust_model.init(
+ {'params': rng}, fake_batch, train=False
+ )
assert pytrees_are_equal(
- initial_variables_original, initial_variables_custom, rtol=1e-6)
+ initial_variables_original, initial_variables_custom, rtol=1e-6
+ )
# forward pass
x = jnp.ones((10,))
- train = mode == "train"
+ train = mode == 'train'
jitted_original_apply = jax.jit(
- partial(orig_model.apply), static_argnames=['train'])
+ partial(orig_model.apply), static_argnames=['train']
+ )
jitted_custom_apply = jax.jit(
- partial(cust_model.apply), static_argnames=['train'])
+ partial(cust_model.apply), static_argnames=['train']
+ )
for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
y1 = jitted_original_apply(
- initial_variables_original,
- x,
- train=train,
- rngs={"dropout": dropout_rng})
+ initial_variables_original,
+ x,
+ train=train,
+ rngs={'dropout': dropout_rng},
+ )
for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
y2 = jitted_custom_apply(
- initial_variables_custom,
- x,
- train=train,
- dropout_rate=d,
- rngs={"dropout": dropout_rng},
+ initial_variables_custom,
+ x,
+ train=train,
+ dropout_rate=d,
+ rngs={'dropout': dropout_rng},
)
assert jnp.allclose(y1, y2, atol=1e-3, rtol=1e-3)
-if __name__ == "__main__":
+if __name__ == '__main__':
absltest.main()
From 631f959dc2b63c4db1f0b6a314639add58321af1 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 31 Jul 2025 16:17:47 +0000
Subject: [PATCH 121/123] add example sbatch script
---
scoring/utils/slurm/run_jobs.sh | 83 +++++++++++++++++++++++++++++++++
1 file changed, 83 insertions(+)
create mode 100644 scoring/utils/slurm/run_jobs.sh
diff --git a/scoring/utils/slurm/run_jobs.sh b/scoring/utils/slurm/run_jobs.sh
new file mode 100644
index 000000000..5fcf8f69e
--- /dev/null
+++ b/scoring/utils/slurm/run_jobs.sh
@@ -0,0 +1,83 @@
+#!/bin/bash
+
+#SBATCH --nodes=1 # give it a full node
+#SBATCH --ntasks-per-node=1
+#SBATCH --array=
+#SBATCH --partition=v100
+#SBATCH --gpus-per-node=8
+#SBATCH --exclusive # do not allow other jobs to share this node
+#SBATCH --output=experiments/tests/jit_debug_deepspeech_old_stephint_nadamw/job_%A_%a.out
+#SBATCH --error=experiments/tests/jit_debug_deepspeech_old_stephint_nadamw/job_%A_%a.err
+
+# Usage: sbatch run_jobs.sh
+# This script reads config.json and launches an sbatch job using task
+# arrays where each job in the array corresponds to a training run
+# for a workload given a random seed and tuning trial index.
+# To generate the config.json use make_job_config.py.
+
+set -x
+
+# Pull docker image (ATTENTION: you may want to modify this)
+REPO=""
+IMAGE=""
+yes | gcloud auth configure-docker $REPO
+docker pull $IMAGE
+# Job config (ATTENTION: you may want to modify this)
+config_file="" # Replace with your config file path
+LOGS_BUCKET="" # replace with your bucket used for logging
+
+
+# Function to read a JSON file and extract a value by key
+read_json_value() {
+ local json_file="$1"
+ local index="$2"
+ local key="$3"
+ local value=$(jq -r ".[\"$index\"].$key" "$json_file")
+ echo "$value"
+}
+
+# Check if jq is installed
+if ! command -v jq &> /dev/null
+then
+ echo "jq could not be found. Please install it."
+ exit 1
+fi
+
+TASK="$SLURM_ARRAY_TASK_ID"
+FRAMEWORK=$(read_json_value "$config_file" "$TASK" "framework")
+DATASET=$(read_json_value "$config_file" "$TASK" "dataset")
+SUBMISSION_PATH=$(read_json_value "$config_file" "$TASK" "submission_path")
+FRAMEWORK=$(read_json_value "$config_file" "$TASK" "framework")
+TUNING_SEARCH_SPACE=$(read_json_value "$config_file" "$TASK" "tuning_search_space")
+EXPERIMENT_DIR=$(read_json_value "$config_file" "$TASK" "experiment_dir")
+MAX_STEPS=$(read_json_value "$config_file" "$TASK" "max_steps")
+RNG_SEED=$(read_json_value "$config_file" "$TASK" "rng_seed")
+WORKLOAD=$(read_json_value "$config_file" "$TASK" "workload")
+HPARAM_START_INDEX=$(read_json_value "$config_file" "$TASK" "hparam_start_index")
+HPARAM_END_INDEX=$(read_json_value "$config_file" "$TASK" "hparam_end_index")
+NUM_TUNING_TRIALS=$(read_json_value "$config_file" "$TASK" "num_tuning_trials")
+TUNING_RULESET=$(read_json_value "$config_file" "$TASK" "tuning_ruleset")
+MAX_GLOBAL_STEPS=$(read_json_value "$config_file" "$TASK" "max_global_steps")
+
+docker run \
+ -v /opt/data/:/data/ \
+ -v $HOME/submissions_algorithms/:/algorithmic-efficiency/submissions_algorithms \
+ --gpus all \
+ --ipc=host \
+ $IMAGE \
+ -d $DATASET \
+ -f $FRAMEWORK \
+ -s $SUBMISSION_PATH \
+ -w $WORKLOAD \
+ -t $TUNING_SEARCH_SPACE \
+ -e $EXPERIMENT_DIR \
+ -c False \
+ -o True \
+ --rng_seed $RNG_SEED \
+ --hparam_start_index $HPARAM_START_INDEX \
+ --hparam_end_index $HPARAM_END_INDEX \
+ --num_tuning_trials $NUM_TUNING_TRIALS \
+ --tuning_ruleset $TUNING_RULESET \
+ --logs_bucket $LOGS_BUCKET \
+ -i true \
+ -r false
\ No newline at end of file
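
The comment block at the top of `run_jobs.sh` describes a `config.json` keyed by SLURM array task index, with one entry per training run, normally generated by `make_job_config.py`. As a rough sketch of that layout (field names taken from the `read_json_value` calls above; every concrete value is a made-up placeholder):

```python
# Hypothetical sketch of the config.json layout the sbatch script expects:
# top-level keys are SLURM array task indices, and each field below is looked
# up via read_json_value (jq '.["$index"].$key'). All values are placeholders.
import json

config = {
  '0': {
    'framework': 'jax',
    'dataset': 'ogbg',
    'submission_path': 'submissions_algorithms/baseline/submission.py',
    'tuning_search_space': 'submissions_algorithms/baseline/tuning_search_space.json',
    'experiment_dir': 'experiments/tests/example_run',
    'max_steps': 10000,
    'rng_seed': 0,
    'workload': 'ogbg',
    'hparam_start_index': 0,
    'hparam_end_index': 1,
    'num_tuning_trials': 1,
    'tuning_ruleset': 'external',
  },
}

with open('config.json', 'w') as f:
  json.dump(config, f, indent=2)
```
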
From db69fcea7174fb085547237bd09eb7b63fd2b789 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 31 Jul 2025 20:59:57 +0000
Subject: [PATCH 122/123] remove index url for jax installation
---
docker/Dockerfile | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 9926b0542..81fb22a6b 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -70,7 +70,7 @@ RUN if [ "$framework" = "jax" ] ; then \
echo "Installing Jax GPU" \
&& cd /algorithmic-efficiency \
&& pip install -e '.[pytorch_cpu]' -f https://download.pytorch.org/whl/cpu \
- && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \
+ && pip install -e '.[jax_gpu]'; \
elif [ "$framework" = "pytorch" ] ; then \
echo "Installing Pytorch GPU" \
&& cd /algorithmic-efficiency \
@@ -80,7 +80,7 @@ RUN if [ "$framework" = "jax" ] ; then \
echo "Installing Jax GPU and Pytorch GPU" \
&& cd /algorithmic-efficiency \
&& pip install -e '.[pytorch_gpu]' \
- && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \
+ && pip install -e '.[jax_gpu]'; \
else \
echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \
&& exit 1 ; \
From 51eb65fe9364780907fdc860df81015d4f392cff Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Thu, 31 Jul 2025 21:00:51 +0000
Subject: [PATCH 123/123] revert change in dockerfile on dev
---
docker/Dockerfile | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/docker/Dockerfile b/docker/Dockerfile
index 81fb22a6b..9926b0542 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -70,7 +70,7 @@ RUN if [ "$framework" = "jax" ] ; then \
echo "Installing Jax GPU" \
&& cd /algorithmic-efficiency \
&& pip install -e '.[pytorch_cpu]' -f https://download.pytorch.org/whl/cpu \
- && pip install -e '.[jax_gpu]'; \
+ && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \
elif [ "$framework" = "pytorch" ] ; then \
echo "Installing Pytorch GPU" \
&& cd /algorithmic-efficiency \
@@ -80,7 +80,7 @@ RUN if [ "$framework" = "jax" ] ; then \
echo "Installing Jax GPU and Pytorch GPU" \
&& cd /algorithmic-efficiency \
&& pip install -e '.[pytorch_gpu]' \
- && pip install -e '.[jax_gpu]'; \
+ && pip install -e '.[jax_gpu]' -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ --pre ; \
else \
echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \
&& exit 1 ; \