# Electrode Masked Auto Encoder Training with Scenic Framework

1.   Implements methods to read in custom (electrodes) datasets, compliant with the Scenic framework.
2.   Adapts the ViT MAE structure to allow for single channel image inputs
3.   Adds custom processing to allow for non-square patches   


#### Colab Kernel (Electrodes Kernel)
Grants command for Access on Demand (AoD):

https://grants.corp.google.com/#/grants?request=20h%2Fchr-ards-electrodes-deid-colab-jobs&reason=b%2F314799341


In [None]:
# @title Imports

import functools
from typing import Any, Callable, Dict, Iterator, Tuple, Optional, Type, Union

from absl import logging
from clu import metric_writers
from clu import periodic_actions
from clu import platform
import flax
from flax import jax_utils
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.profiler
import ml_collections
import numpy as np
import optax
from colabtools import adhoc_import
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds
import time
import tqdm

with adhoc_import.Google3():
  from scenic.dataset_lib import dataset_utils
  from scenic.google.xm import xm_utils
  from scenic.model_lib.base_models import base_model
  from scenic.model_lib.base_models import model_utils
  from scenic.projects.multimask.models import model_utils as mm_model_utils
  from scenic.model_lib.layers import nn_layers
  from scenic.model_lib.layers import nn_ops
  # To register the preprocessing ops
  from scenic.projects.multimask import data_utils  # pylint: disable=unused-import

  from scenic.train_lib import optax as scenic_optax
  from scenic.train_lib import pretrain_utils
  from scenic.train_lib import train_utils
  from scenic.train_lib.transfer import fewshot_utils
  from scenic.projects.multimask import trainer
  from scenic.projects.multimask.models import transformer_encoder
  from scenic.projects.baselines import vit
  from scenic.projects.multimask.models import vit_encoder
  from scenic.projects.multimask.models import vit_mae
  from scenic.projects.baselines.configs.google.common import common_fewshot

# Aliases for custom types:
# A batch of data: maps a field name (e.g. 'input_signal', 'targets') to an
# array.
Batch = Dict[str, jnp.ndarray]
# Metrics callable taking two arrays and a batch dict, and returning
# {metric name: (value, normalizer)}. In this file it is invoked as
# `metrics_fn(logits, token_mask, batch)`.
MetricFn = Callable[
    [jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray]],
    Dict[str, Tuple[float, int]],
]
# Loss callable. In this file it is invoked as
# `loss_fn(logits, batch, params, token_mask)` and returns a scalar loss.
LossFn = Callable[
    [jnp.ndarray, Batch, Optional[jnp.ndarray], jnp.ndarray], float
]
# Learning-rate schedules keyed by name; each maps a step to a learning rate.
LrFns = Dict[str, Callable[[jnp.ndarray], jnp.ndarray]]
# Patch size as a 2-tuple or 3-tuple of ints.
Patch = Union[Tuple[int, int], Tuple[int, int, int]]


## Sample dataset visualization and metrics

In [None]:
# @title Electrodes dataset meta data
# Name passed to `tfds.builder` / `tfds.load` to select the dataset.
dataset_name = 'lsm_prod/lsm_300min_100K_unimpute'
# Root directory under which the TFDS datasets are stored.
data_dir = '/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/deid/raw/datasets/msa_1_5/lsm_tfds_datasets'
# CSV read in the next cell; its column names are used as the feature names.
# NOTE(review): this path repeats `data_dir` and `dataset_name`; keep the three
# values in sync when changing any of them.
featurenames_csv = '/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/deid/raw/datasets/msa_1_5/lsm_tfds_datasets/lsm_prod/lsm_300min_100K_unimpute/1.0.0/Dataset_FeatureNames.csv'

print('TF Dataset Information')
# Last expression of the cell, so the notebook displays the dataset info.
tfds.builder(dataset_name, data_dir=data_dir, try_gcs=True).info

In [None]:
# @title Load dataset
# Load the 'train' split of the dataset selected in the metadata cell.
ds = tfds.load(dataset_name, data_dir=data_dir, split='train', shuffle_files=True)
assert isinstance(ds, tf.data.Dataset)
print(f'Dataset:\n{ds}\n')

# Read the feature-name CSV through TensorFlow's filesystem layer. The
# previous `gfile.Open` relied on a `gfile` module that is never imported in
# this notebook, which raises NameError; `tf` is already imported.
with tf.io.gfile.GFile(featurenames_csv, 'r') as f:
  df = pd.read_csv(f)

# The CSV header row holds the feature names.
features = df.columns
print(f'Features:\n{features}\n')

In [None]:
# @title Plot sample feature image

# NOTE: this rebinds `ds` to a one-element dataset; the following cell
# iterates this same `ds`.
ds = ds.take(1)  # Only take a single example

for example in ds:  # example is a dict of tensors; 'input_signal' and 'label' are read below
  inputs = example["input_signal"]
  label = example["label"]
  print('Example meta data:')
  print('Example Keys:', list(example.keys()))
  print('Inputs shape:', inputs.shape)
  # NOTE(review): despite the text, this prints the label tensor itself, not
  # its shape.
  print('Labels shape:', label)

  plt.figure(figsize=(15,10))
  # Swap the first two axes so that axis 0 of the signal runs along the x-axis
  # (labelled as time below) and axis 1 along the y-axis (labelled as feature).
  imgplot = plt.imshow(np.swapaxes(inputs,0,1))
  plt.grid(None)
  plt.xlabel('Time (minutes)')
  plt.ylabel('Feature #')
  plt.show()

In [None]:
# @title Plot individual features

print('Printing feature rows separately...')
# `ds` and `features` are globals set by earlier cells.
for example in ds:  # example is a dict of tensors; 'input_signal' and 'label' are read below
  inputs = example["input_signal"]
  label = example["label"]

  # One subplot per feature, stacked vertically. The count of 25 is
  # hard-coded -- assumes the signal has at least 25 features along axis 1 and
  # that `features` has at least 25 names; TODO confirm.
  fig, axs = plt.subplots(25, 1, figsize=(10,35))#, layout='constrained')
  for i, ax in enumerate(axs):
    ax.plot(inputs[:,i])
    ax.set_title(features[i])
    # Only the bottom subplot keeps its x-axis ticks.
    if i < len(axs) - 1:
      ax.get_xaxis().set_ticks([])

## Helpers (Functions and Classes)

In [None]:
# @title Trainer functions (adapted from trainer.py)

""" Trainer functions.

Adapted from google3/third_party/py/scenic/projects/multimask/trainer.py.

The functions below are adapted to allow for a different input field name.
The original, expected field name was 'inputs'; here it has
been changed to 'input_signal'.
"""


def get_targets(batch: Batch, config: ml_collections.ConfigDict) -> jnp.ndarray:
  """Builds the prediction targets from the batch's 'input_signal' field.

  Adapted from: google3/third_party/py/scenic/projects/multimask/trainer.py

  Args:
    batch: Batch of data; the 'input_signal' entry is used as the source of
      the targets.
    config: Experiment configuration. Reads
      `masked_feature_loss.targets_type`, and then either
      `model.patches.size` (for 'rgb') or `model.vocab_size` (for 'tokens').

  Returns:
    For 'rgb', the output of `get_rgb_targets` on the input signal; for
    'tokens', a one-hot encoding of the input signal.

  Raises:
    ValueError: If `targets_type` is neither 'rgb' nor 'tokens'.
  """
  kind = config.masked_feature_loss.targets_type

  if kind == 'tokens':
    return nn.one_hot(
        batch['input_signal'], num_classes=config.model.vocab_size
    )

  if kind == 'rgb':
    patch_size = tuple(config.model.patches.size)
    return get_rgb_targets(batch['input_signal'], patch_size)

  raise ValueError(f'Unknown targets_type {kind}')


def get_rgb_targets(
    inputs: jnp.ndarray,
    patch_size: Patch,
    reconstruct_grayscale: Optional[bool] = False,
    standardise_per_patch: Optional[bool] = False,
) -> jnp.ndarray:
  """Get raw-pixel patch targets to use for feature regression.

  Adapted from: google3/third_party/py/scenic/projects/multimask/trainer.py

  The targets are the raw patches of the input image, flattened to one vector
  per patch.

  Args:
    inputs: Tensor of shape [b, h, w, c]. Only 4D (image) inputs are accepted.
    patch_size: The patch size forwarded to `nn_ops.patch_image`, e.g.
      [ph, pw].
    reconstruct_grayscale: If True, the inputs are first multiplied by a 3x3
      luma matrix, so every output channel holds the grayscale value. This
      requires a 3-channel input.
    standardise_per_patch: If True, standardise each flattened patch by
      subtracting its mean and dividing by its standard deviation.

  Returns:
    Patched inputs of shape [b, num_tokens, -1], where num_tokens is the
    product of the two patch-grid dimensions returned by
    `nn_ops.patch_image`.

  Raises:
    ValueError: If `inputs` is not 4D.
  """
  if inputs.ndim != 4:
    # The f-prefix was missing before, so the message showed the literal text
    # '{inputs.shape}' instead of the offending shape.
    raise ValueError(
        f'Inputs should be 4D (images). Shape {inputs.shape}'
    )

  if reconstruct_grayscale:
    # Reference for converting between RGB and grayscale.
    # https://en.wikipedia.org/wiki/Luma_%28video%29
    # Also used in tf.image.rgb_to_grayscale
    rgb_weights = jnp.tile(jnp.array([[0.2989, 0.5870, 0.1140]]), (3, 1)).T
    # The matmul acts on the last axis only, so the input stays 4D.
    inputs = jnp.matmul(inputs, rgb_weights)

  batch = inputs.shape[0]
  # Shape is [batch, ht, wt, hp, wp, c]
  patched_image = nn_ops.patch_image(
      inputs, inputs_shape=None, patch_size=patch_size
  )
  # Collapse the patch grid into a token axis and flatten each patch.
  num_tokens = patched_image.shape[1] * patched_image.shape[2]
  patched_input = jnp.reshape(patched_image, (batch, num_tokens, -1))

  if standardise_per_patch:
    patched_input = jax.nn.standardize(patched_input, axis=-1, epsilon=1e-6)

  return patched_input


# Forked from projects/mfp/trainer.py
def representation_fn(
    train_state: train_utils.TrainState,
    batch: Batch,
    *,
    flax_model: nn.Module,
    representation_layer: str,
    gather_to_host: bool = True
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Runs the model in inference mode and extracts a representation.

  Adapted from: google3/third_party/py/scenic/projects/multimask/trainer.py

  The model is applied to `batch['input_signal']` with `train=False`, and the
  entry named `representation_layer` is read from the auxiliary outputs. A 3D
  representation is averaged over axis 1 (the token axis).

  Args:
    train_state: Training state; `params` and `model_state` are used to build
      the model variables.
    batch: A single batch of data. Must provide 'input_signal', 'label' and
      'batch_mask'.
    flax_model: A Flax model whose `apply` returns `(output, aux)`.
    representation_layer: Key into `aux` selecting the representation.
    gather_to_host: If True, the representation and the batch are all-gathered
      over the 'batch' axis before being returned.

  Returns:
    A tuple `(representation, labels, batch_mask)`.
  """
  model_variables = {'params': train_state.params, **train_state.model_state}
  _, aux = flax_model.apply(
      model_variables, batch['input_signal'], train=False, debug=False
  )

  features = aux[representation_layer]
  if features.ndim == 3:
    # Feature regression models return [batch, num_tokens, channels]; pool
    # over the token axis.
    logging.info('Representation shape before pooling tokens: %s',
                 features.shape)
    features = jnp.mean(features, axis=1)
  logging.info('Representation shape: %s', features.shape)

  if not gather_to_host:
    return features, batch['label'], batch['batch_mask']

  gathered_features = jax.lax.all_gather(features, 'batch')
  gathered_batch = jax.lax.all_gather(batch, 'batch')
  return (
      gathered_features,
      gathered_batch['label'],
      gathered_batch['batch_mask'],
  )


# Forked from projects/baselines/plainvit/trainer.py
def train_step(
    train_state: train_utils.TrainState,
    batch: Batch,
    *,
    flax_model: nn.Module,
    loss_fn: LossFn,
    lr_fns: LrFns,
    metrics_fn: MetricFn,
    config: ml_collections.ConfigDict,
    debug: Optional[bool] = False,
) -> Tuple[
    train_utils.TrainState, Dict[str, Tuple[float, int]], Dict[str, Any]
]:
  """Performs one optimisation step on a batch.

  Adapted from: google3/third_party/py/scenic/projects/multimask/trainer.py

  The targets are derived from the batch and stored under `batch['targets']`
  (the input dict is modified). The model is then applied to
  `batch['input_signal']` in training mode, gradients of the loss are averaged
  over the 'batch' axis, and the optimizer update is applied.

  Args:
    train_state: The state of training; `rng`, `params`, `model_state`, `tx`,
      `opt_state` and `global_step` are read.
    batch: A single batch of data containing 'input_signal' plus whatever
      `loss_fn` and `metrics_fn` need.
    flax_model: A Flax model whose `apply` returns `(logits, aux)` with
      `aux['token_mask']`.
    loss_fn: Called as `loss_fn(logits, batch, params, token_mask)`; returns
      the scalar loss.
    lr_fns: The learning rate fns used for the optimizer in train_state, keyed
      by name. The key 'all' is logged as 'learning_rate', any other key as
      'learning_rate_<name>'.
    metrics_fn: Called as `metrics_fn(logits, token_mask, batch)`.
    config: Configurations of the experiment.
    debug: Forwarded to the model's `apply` as `debug`.

  Returns:
    The updated train state, the computed metrics, and a dict of training logs
    (L2 norms of gradients, parameters and updates, plus learning rates).
  """
  next_rng, step_rng = jax.random.split(train_state.rng)

  # Bind the rng to the host/device we are on.
  dropout_rng = train_utils.bind_rng_to_host_device(
      step_rng, axis_name='batch', bind_to='device'
  )

  # Add prediction targets
  batch['targets'] = get_targets(batch, config)

  def training_loss_fn(params):
    variables = {'params': params, **train_state.model_state}
    (logits, aux), updated_model_state = flax_model.apply(
        variables,
        batch['input_signal'],
        mutable=['batch_stats'],
        train=True,
        rngs={'dropout': dropout_rng},
        debug=debug,
    )
    token_mask = aux['token_mask']
    loss = loss_fn(logits, batch, variables['params'], token_mask)
    return loss, (updated_model_state, logits, token_mask)

  # The loss value itself is not needed past this point.
  (_, (new_model_state, logits, masks)), grad = jax.value_and_grad(
      training_loss_fn, has_aux=True
  )(train_state.params)

  # Re-use same axis_name as in the call to `pmap(...train_step...)`.
  grad = jax.lax.pmean(grad, axis_name='batch')

  updates, new_opt_state = train_state.tx.update(
      grad, train_state.opt_state, train_state.params
  )
  new_params = optax.apply_updates(train_state.params, updates)

  def l2_norm(tree):
    # Global L2 norm over all leaves of a pytree.
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum([jnp.vdot(leaf, leaf) for leaf in leaves]))

  training_logs = {
      'l2_grads': l2_norm(grad),
      'l2_params': l2_norm(new_params),
      'l2_updates': l2_norm(updates),
  }
  for name, lr_fn in lr_fns.items():
    log_key = 'learning_rate' if name == 'all' else f'learning_rate_{name}'
    training_logs[log_key] = lr_fn(train_state.global_step)

  metrics = metrics_fn(logits, masks, batch)

  new_train_state = train_state.replace(  # pytype: disable=attribute-error
      global_step=train_state.global_step + 1,
      opt_state=new_opt_state,
      params=new_params,
      model_state=new_model_state,
      rng=next_rng,
  )
  return new_train_state, metrics, training_logs


def eval_step(
    train_state: train_utils.TrainState,
    batch: Batch,
    *,
    flax_model: nn.Module,
    metrics_fn: MetricFn,
    config: ml_collections.ConfigDict,
    debug: Optional[bool] = False,
) -> Tuple[Dict[str, Tuple[float, int]], jnp.ndarray]:
  """Runs a single step of evaluation.

  Adapted from: google3/third_party/py/scenic/projects/multimask/trainer.py

  The targets are derived from the batch and stored under `batch['targets']`
  (the input dict is modified). The model is applied with `train=True` so that
  token masking stays enabled, but with `mutable=False` so no model state is
  updated. The rng is always built from `config.rng_seed`, which keeps the
  evaluation as consistent as possible between calls.

  Args:
    train_state: The state of training; `params` and `model_state` are used to
      build the model variables.
    batch: A single batch of data containing 'input_signal' plus whatever
      `metrics_fn` needs.
    flax_model: A Flax model whose `apply` returns `(logits, aux)` with
      `aux['token_mask']`.
    metrics_fn: Called as `metrics_fn(logits, token_mask, batch)`.
    config: Configurations of the experiment; `rng_seed` is read here.
    debug: Forwarded to the model's `apply` as `debug`.

  Returns:
    Calculated metrics and logits.
  """
  # Add prediction targets
  batch['targets'] = get_targets(batch, config)

  # Fixed seed, bound to the host/device we are on.
  eval_rng = train_utils.bind_rng_to_host_device(
      jax.random.PRNGKey(config.rng_seed), axis_name='batch', bind_to='device'
  )

  model_variables = {'params': train_state.params, **train_state.model_state}
  logits, aux = flax_model.apply(
      model_variables,
      batch['input_signal'],
      train=True,  # so that masking is enabled
      mutable=False,
      rngs={'dropout': eval_rng},
      debug=debug,
  )
  return metrics_fn(logits, aux['token_mask'], batch), logits


# Overwrite package functions with functions that allow for an input of key name
# 'input_signal'.
# These assignments rebind attributes on the imported `trainer` module, so they
# take effect for every later use of `trainer.<name>` in this session.
# NOTE(review): code inside `trainer` that captured the original functions
# before this cell ran would not see the replacements -- confirm `trainer.train`
# looks these names up at call time.
trainer.get_targets = get_targets
trainer.representation_fn = representation_fn
trainer.train_step = train_step
trainer.eval_step = eval_step

In [None]:
# @title Single channel ViT MAE (adapted from vit_mae.py)

"""ViT encoder-decoder models for multimask.

Adapted from google3/third_party/py/scenic/projects/multimask/models/vit_mae.py.
The functions/classes below are adapted to allow for single-channel input
images, whereas the original implementation was for 3-channel RGB images.
"""

# Mostly copied from ViTMaskedAutoencoder in projects/mfp/vit.py
class ViTMAE(nn.Module):
  """Encoder-decoder Vision Transformer model for masked feature regression.

  Copied from google3/third_party/py/scenic/projects/multimask/models/vit_mae.py.

  The differences to `ViTMaskedModel` from vit_encoder.py are that:
  -- Only non-masked tokens are processed by the encoder
  -- The parallel decoder then processes all tokens

  Attributes:
    num_classes: Size of the final output projection, i.e. the number of values
      predicted per token.
    mlp_dim: Dimension of the mlp on top of attention block.
    num_layers: Number of layers.
    num_heads: Number of self-attention heads.
    patches: Configuration of the patches extracted in the stem of the model;
      `patches.size` is unpacked as (patch height, patch width).
    hidden_size: Size of the hidden state of the output of model's stem.
    token_mask_probability: String of the form 'constant_<float>'; the float is
      the fraction of tokens masked out during training.
    decoder_config: Configuration of the decoder. The fields `hidden_size`,
      `mlp_dim`, `num_layers`, `num_heads`, `dropout_rate` and
      `attention_dropout_rate` are read, plus an optional `stochastic_depth`
      (defaults to 0.0).
    representation_size: Size of the representation layer in the model's head.
      if None, we skip the extra projection + tanh activation at the end.
    positional_embedding: Type of positional embedding added to the encoder
      tokens; forwarded to `mm_model_utils.add_positional_embeddings`.
    positional_embedding_decoder: Type of positional embedding added to the
      decoder tokens; forwarded to `mm_model_utils.add_positional_embeddings`.
    dropout_rate: Dropout rate.
    attention_dropout_rate: Dropout for attention heads.
    stochastic_depth: Probability of dropping out a layer during training.
    classifier: type of the classifier layer. Only the value 'token' changes
      behaviour here: it prepends a learned class token before the encoder.
    dtype: JAX data type for activations.
  """

  num_classes: int
  mlp_dim: int
  num_layers: int
  num_heads: int
  patches: ml_collections.ConfigDict
  hidden_size: int
  token_mask_probability: str
  decoder_config: ml_collections.ConfigDict
  representation_size: Optional[int] = None
  positional_embedding: str = 'sinusoidal_2d'
  positional_embedding_decoder: str = 'sinusoidal_2d'
  dropout_rate: float = 0.1
  attention_dropout_rate: float = 0.1
  stochastic_depth: float = 0.0
  classifier: str = 'none'
  dtype: jnp.dtype = jnp.float32

  @nn.compact
  def __call__(self, x: jnp.ndarray, *, train: bool, debug: bool = False):
    """Forward pass of Vision Transformer.

    Args:
      x: Input images; axes 1 and 2 must be divisible by the patch height and
        width respectively.
      train: If True, a subset of tokens is masked out before the encoder and
        the decoder reconstructs all tokens. If False, all tokens go through
        the encoder and the decoder is skipped.
      debug: Not used in this method.

    Returns:
      A tuple `(output, aux)`. With `train=False`, `output` is the encoder
      output and `aux` holds 'token_mask' (all zeros) and 'pre_logits'. With
      `train=True`, `output` is the result of the final `num_classes`-wide
      projection for every token, and `aux` additionally holds
      'pre_logits_decoder'.
    """

    fh, fw = self.patches.size
    assert x.shape[1] % fh == 0 and x.shape[2] % fw == 0, (
        'Height and width should be divisible by the respective patch sizes,'
        f' instead got {x.shape[1:3]} and {(fh, fw)}'
    )
    # Extracting patches and then embedding is in fact a single convolution.
    x = nn.Conv(
        self.hidden_size,
        (fh, fw),
        strides=(fh, fw),
        padding='VALID',
        name='embedding',
    )(x)
    # Flatten the patch grid into a sequence of tokens.
    batch, height, width, channels = x.shape
    x = jnp.reshape(x, [batch, height * width, channels])

    # Add positional encodings before removing the masked tokens
    x = mm_model_utils.add_positional_embeddings(
        x, self.positional_embedding, [batch, height, width, channels]
    )

    # Remove masked tokens if needed
    n_tokens = height * width
    if train:
      # Generate mask indices.
      assert self.token_mask_probability.startswith('constant_'), (
          'Only constant token_mask_probability supported in MAE, instad got'
          f' {self.token_mask_probability}'
      )
      # Parse the fraction out of 'constant_<float>'; the number of masked
      # tokens is rounded down.
      token_mask_probability = float(self.token_mask_probability.split('_')[1])
      n_masked = int(token_mask_probability * n_tokens)
      # The 'dropout' rng stream drives the random choice of masked tokens.
      mask_indices, unmasked_indices, token_mask = (
          mm_model_utils.get_mask_indices(
              batch, n_tokens, n_masked, self.make_rng('dropout')
          )
      )

      # Process only unmasked tokens with the encoder.
      batch_indices = jnp.arange(batch).reshape(batch, 1)
      x = x[batch_indices, unmasked_indices]
    else:
      # Nothing is masked at inference time.
      token_mask = jnp.zeros((batch, n_tokens))

    aux = {'token_mask': token_mask}

    # If we want to add a class token, add it here.
    # Note that in MAE, positional encodings are not added to the CLS token.
    if self.classifier == 'token':
      cls = self.param('cls', nn.initializers.zeros,
                       (1, 1, channels), x.dtype)
      cls = jnp.tile(cls, [batch, 1, 1])
      x = jnp.concatenate([cls, x], axis=1)

    x = vit.Encoder(
        mlp_dim=self.mlp_dim,
        num_layers=self.num_layers,
        num_heads=self.num_heads,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        stochastic_depth=self.stochastic_depth,
        dtype=self.dtype,
        positional_embedding='none',  # Has already been added.
        name='Transformer')(
            x, train=train)
    aux['pre_logits'] = x

    # If not training, skip decoding
    if not train:
      return x, aux

    # Process entire sequence with the decoder.
    # Learned token that stands in for every masked position.
    mask_token = self.param('mask_token',
                            nn.initializers.zeros,
                            (1, 1, self.decoder_config.hidden_size))

    # Project encoder outputs to the decoder width.
    x = nn.Dense(
        self.decoder_config.hidden_size,
        kernel_init=nn.initializers.xavier_uniform(),
        name='decoder_projection')(x)
    # Drop the class token again; the decoder only sees patch tokens.
    if self.classifier == 'token':
      x = x[:, 1:, :]

    # This effectively "unshuffles" the tokens. This means that we can simply
    # add positional encodings in the decoder without having to worry about
    # their ordering.
    x_all = jnp.zeros((batch, n_tokens, self.decoder_config.hidden_size))
    x_all = x_all.at[batch_indices, unmasked_indices].set(x)
    x_all = x_all.at[batch_indices, mask_indices].set(mask_token)
    x = x_all
    del x_all

    # Add positional encodings to the decoder.
    x = mm_model_utils.add_positional_embeddings(
        x, self.positional_embedding_decoder,
        [batch, height, width, self.decoder_config.hidden_size])

    # The parallel decoder, which is actually technically an encoder
    x = vit.Encoder(
        mlp_dim=self.decoder_config.mlp_dim,
        num_layers=self.decoder_config.num_layers,
        num_heads=self.decoder_config.num_heads,
        dropout_rate=self.decoder_config.dropout_rate,
        attention_dropout_rate=self.decoder_config.attention_dropout_rate,
        stochastic_depth=self.decoder_config.get('stochastic_depth', 0.0),
        dtype=self.dtype,
        positional_embedding='none',  # Has already been added.
        name='Decoder')(x, train=train)

    # Predict pixel reconstructions.
    if self.representation_size is not None:
      x = nn.Dense(self.representation_size, name='pre_logits')(
          x)
      x = nn.tanh(x)
    else:
      x = nn_layers.IdentityLayer(name='pre_logits')(x)
    aux['pre_logits_decoder'] = x
    # Zero-initialised kernel for the final projection.
    x = nn.Dense(
        self.num_classes,
        kernel_init=nn.initializers.zeros,
        name='output_projection')(x)

    return x, aux


# Copied from google3/third_party/py/scenic/projects/multimask/models/vit_mae.py.
def _regression_metric(loss_type, apply_prediction_weights):
  """Builds one (metric, normalizer, apply_prediction_weights) entry."""
  metric_fn = functools.partial(
      mm_model_utils.weighted_error, loss_type=loss_type
  )
  return (metric_fn, model_utils.num_examples, apply_prediction_weights)


# Metric name -> (metric, normalizer, apply_prediction_weights).
# The '*_all' variants are weighted by the batch mask only; the '*_masked'
# variants are additionally weighted by the prediction masks.
_REGRESSION_METRICS = {
    'mean_squared_error_all': _regression_metric('squared', False),
    'mean_absolute_error_all': _regression_metric('absolute', False),
    'mean_squared_error_masked': _regression_metric('squared', True),
    'mean_absolute_error_masked': _regression_metric('absolute', True),
}


def regression_metrics_function(
    predictions: jnp.ndarray,
    prediction_masks: jnp.ndarray,
    batch: base_model.Batch,
    metrics: base_model.MetricNormalizerFnDict,
    axis_name: Union[str, Tuple[str, ...]] = 'batch',
) -> Dict[str, Tuple[float, int]]:
  """Calculate metrics for the regression task.

  Adapted from google3/third_party/py/scenic/projects/multimask/models/vit_mae.py.

  Each metric and normalizer is invoked here as
    ```fn(targets, predictions, weights)```
  and the resulting (metric, normalizer) pair is summed across devices with
  `model_utils.psum_metric_normalizer`. The caller is responsible for dividing
  by the normalizer when computing the mean of each metric.

  NOTE(review): `loss_function` in this file calls `weighted_error` with
  (predictions, targets); here the order is (targets, predictions). Confirm
  that `weighted_error` is symmetric in its first two arguments.

  Args:
   predictions: Output of model in shape [batch, length, channels].
   prediction_masks: Masks used for masked modeling, shape [batch, length]
   batch: Batch (dict) with keys 'targets' and optionally 'batch_mask'. When
     'batch_mask' is absent, every example gets a weight of one.
   metrics: The regression metrics to evaluate. The key is the name of the
     metric, and the value is the metrics function, normalizer, and a bool
     indicating whether to apply prediction_masks.
   axis_name: Axis name (or names) over which the psum is run.

  Returns:
    A dict of metrics, in which keys are metrics name and values are tuples of
    (metric, normalizer).
  """
  targets = batch['targets']
  batch_weights = batch.get('batch_mask')
  if batch_weights is None:
    # 'batch_mask' is documented as optional. Without this default the
    # expand_dims below fails on None; all-ones weights count every example.
    batch_weights = jnp.ones((targets.shape[0],))

  # Per-token weights: example validity combined with the token mask.
  token_weights = jnp.expand_dims(batch_weights, axis=-1) * prediction_masks

  evaluated_metrics = {}
  for name, (metric_fn, normalizer_fn, use_prediction_masks) in metrics.items():
    metric_weights = token_weights if use_prediction_masks else batch_weights
    evaluated_metrics[name] = model_utils.psum_metric_normalizer(
        (
            metric_fn(
                targets,
                predictions,  # pytype: disable=wrong-arg-types  # jax-ndarray
                metric_weights,
            ),
            # The normalizer always uses the per-example weights.
            normalizer_fn(
                targets,
                predictions,  # pytype: disable=wrong-arg-types  # jax-ndarray
                batch_weights,
            ),
        ),
        axis_name=axis_name,
    )
  return evaluated_metrics  # pytype: disable=bad-return-type  # jax-ndarray


class ViTMAESingleChannelModel(base_model.BaseModel):
  """ViT-based masked modeling for single-channel inputs.

  Adapted from google3/third_party/py/scenic/projects/multimask/models/vit_mae.py.

  The model predicts one value per pixel of each patch, so the output width is
  the product of the patch dimensions times a single channel.
  """

  def build_flax_model(self) -> nn.Module:
    """Builds the `ViTMAE` Flax module from `self.config`."""
    model_dtype = getattr(jnp, self.config.get('model_dtype_str', 'float32'))
    # One regression output per pixel of a patch. `np.prod` returns a NumPy
    # scalar, so convert to a plain int to match ViTMAE's `num_classes: int`.
    num_channels = 1
    pixels_per_patch = int(np.prod(tuple(self.config.model.patches.size)))
    num_classes = pixels_per_patch * num_channels

    return ViTMAE(
        num_classes=num_classes,
        mlp_dim=self.config.model.mlp_dim,
        num_layers=self.config.model.num_layers,
        num_heads=self.config.model.num_heads,
        representation_size=self.config.model.representation_size,
        positional_embedding=self.config.model.positional_embedding,
        positional_embedding_decoder=self.config.model.positional_embedding_decoder,
        patches=self.config.model.patches,
        hidden_size=self.config.model.hidden_size,
        token_mask_probability=(
            self.config.masked_feature_loss.token_mask_probability
        ),
        classifier='none',
        dropout_rate=self.config.model.get('dropout_rate', 0.1),
        attention_dropout_rate=self.config.model.get(
            'attention_dropout_rate', 0.1
        ),
        stochastic_depth=self.config.model.get('stochastic_depth', 0.0),
        decoder_config=self.config.model.decoder_config,
        dtype=model_dtype,
    )

  def default_flax_model_config(self) -> ml_collections.ConfigDict:
    """Returns an empty config; this model defines no defaults."""
    return ml_collections.ConfigDict()

  def init_from_train_state(
      self,
      train_state: Any,
      restored_train_state: Any,
      restored_model_cfg: ml_collections.ConfigDict,
  ) -> Any:
    """Updates the train_state with data from restored_train_state.

    This function is written to be used for 'fine-tuning' experiments. All of
    the work is delegated to `vit.init_vit_from_train_state`.

    Args:
      train_state: A raw TrainState for the model.
      restored_train_state: A TrainState that is loaded with parameters/state of
        a  pretrained model.
      restored_model_cfg: Configuration of the model from which the
        restored_train_state come from. Usually used for some asserts.

    Returns:
      Updated train_state.
    """
    return vit.init_vit_from_train_state(
        train_state, restored_train_state, self.config, restored_model_cfg
    )

  # prediction_masks at the last position to fit the parent class func signature
  def loss_function(
      self,
      predictions: jnp.ndarray,
      batch: base_model.Batch,
      model_params: Optional[jnp.ndarray] = None,
      prediction_masks: Optional[jnp.ndarray] = None,
  ) -> float:
    """Returns the weighted reconstruction error.

    The error type is selected by `config.masked_feature_loss.loss_type`.

    Args:
      predictions: Output of model in shape [batch, length, channels].
      batch: Batch (dict) with keys 'targets' and 'batch_mask'.
      model_params: Parameters of the model. Not used by this loss.
      prediction_masks: Masks used for masked modeling, shape [batch, length].
        Required when `config.masked_feature_loss.loss_only_masked_tokens` is
        set.

    Returns:
      The scalar loss, averaged over all axes including the batch axis.

    Raises:
      ValueError: If the loss is restricted to masked tokens but no
        `prediction_masks` were given.
    """
    # IIUC, this mask can be provided by the data loader to indicate invalid
    # examples, e.g. for incomplete batches during eval
    weights = batch['batch_mask']  # shape (batch_size,)

    # If requested, compute the loss only on masked tokens.
    if self.config.masked_feature_loss.loss_only_masked_tokens:
      if prediction_masks is None:
        # Fail with a clear message instead of an opaque error from
        # multiplying an array by None.
        raise ValueError(
            'prediction_masks must be provided when '
            'masked_feature_loss.loss_only_masked_tokens is set.'
        )
      weights = jnp.expand_dims(weights, axis=-1) * prediction_masks

    targets = batch['targets']

    total_loss = mm_model_utils.weighted_error(
        predictions,
        targets,
        weights,
        axis=tuple(range(targets.ndim)),  # aggregate over the batch axis too
        loss_type=self.config.masked_feature_loss.loss_type,
        mean=True,
    )
    return total_loss  # pytype: disable=bad-return-type  # jax-ndarray

  def get_metrics_fn(self, split: Optional[str] = None) -> base_model.MetricFn:
    """Returns a callable metric function for the model.

    The same metric function is returned for every split.

    Args:
      split: The split for which we calculate the metrics. It should be one of
        the ['train',  'validation', 'test']. Ignored.
    Returns: A metric function with the following API:
      ```metrics_fn(predictions, prediction_masks, batch)```
    """

    del split  # Same function for all splits.
    return functools.partial(
        regression_metrics_function, metrics=_REGRESSION_METRICS
    )


# Attach the single-channel model to the imported `vit_mae` module so that
# `get_model_cls` in this notebook can look it up as
# `vit_mae.ViTMAESingleChannelModel`.
vit_mae.ViTMAESingleChannelModel = ViTMAESingleChannelModel

In [None]:
# @title Model Class and Trainer Selector Functions

""" Adapted from google3/third_party/py/scenic/projects/multimask/main.py
"""

def get_model_cls(model_name: str):
  """Returns the Multimask model class registered under `model_name`.

  Args:
    model_name: One of 'vit_masked_encoder', 'vit_mae',
      'vit_mae_single_channel' or 'transformer_masked_encoder'.

  Returns:
    The matching model class.

  Raises:
    ValueError: If `model_name` is not one of the names above.
  """
  if model_name == 'vit_masked_encoder':
    return vit_encoder.VitMaskedEncoderModel

  if model_name == 'vit_mae':
    return vit_mae.ViTMAEModel

  if model_name == 'vit_mae_single_channel':
    # Make the choice of the notebook-defined model visible in the cell output.
    print('Model: ViTMAESingleChannelModel')
    return vit_mae.ViTMAESingleChannelModel

  if model_name == 'transformer_masked_encoder':
    return transformer_encoder.TransformerMaskedEncoderModel

  raise ValueError(f'Unrecognized model: {model_name}.')

def get_train_fn(trainer_name):
  """Returns the training function registered under `trainer_name`.

  Args:
    trainer_name: Name of the trainer; only 'multimask_trainer' is known.

  Returns:
    `trainer.train` for 'multimask_trainer'.

  Raises:
    ValueError: If `trainer_name` is not recognized.
  """
  if trainer_name != 'multimask_trainer':
    raise ValueError(f'Unrecognized trainer: {trainer_name}.')
  return trainer.train


In [None]:
# @title Electrode Model Dataloader

""" Adapted from a combination of the following files:
google3/third_party/py/scenic/dataset_lib/cifar10_dataset.py
google3/third_party/py/scenic/dataset_lib/dataset_utils.py
"""

def preprocess_example(example, patch_size=None, dtype=tf.float32):
  """Casts the input signal and fits it to a whole number of patches.

  Adapted from google3/third_party/py/scenic/dataset_lib/cifar10_dataset.py

  When `patch_size` is given, leading entries along axis 0 are cropped so its
  length becomes a multiple of the patch height, and axis 1 is zero-padded on
  both sides (the extra element, if any, goes at the end) so its length
  becomes a multiple of the patch width.

  Args:
    example: dict; Example that has an 'input_signal' entry of rank 3.
    patch_size: Optional (patch height, patch width). If None, the signal is
      only cast.
    dtype: Tensorflow data type; Data type of the returned signal.

  Returns:
    A dict holding only the preprocessed 'input_signal'; any other keys of
    `example` are not carried over.

  NOTE: This assumes that the image is in the shape [H, W, C],
    where H is the Time axis, and W is the feature axis.
  """
  signal = tf.cast(example['input_signal'], dtype=dtype)
  height, width, _ = signal.shape

  if patch_size is not None:
    patch_h, patch_w = patch_size

    # Drop the leading rows that do not fill a whole patch along H.
    rows_to_drop = height - int((height // patch_h) * patch_h)
    signal = signal[rows_to_drop:, :, :]

    # Zero-pad W up to the next multiple of the patch width.
    leftover_w = width % patch_w
    if leftover_w != 0:
      total_pad = patch_w - leftover_w
      pad_before = total_pad // 2
      pad_after = total_pad - pad_before
      signal = tf.pad(
          signal,
          [[0, 0], [pad_before, pad_after], [0, 0]],
          mode='CONSTANT',
          constant_values=0,
      )

  # Return preprocessed feature
  return {'input_signal': signal}


# TODO(girishvn): define augmentation function
# def augment_example(example, dtype=tf.float32, data_augmentations=None):
#   """Apply data augmentation on the given training example.


def get_electrodes_dataset(
    *,
    config,
    num_shards,
    batch_size,
    eval_batch_size=None,
    dtype_str='float32',
    shuffle_seed=0,
    rng=None,
    shuffle_buffer_size=None,
    dataset_configs=None,
    dataset_service_address: Optional[str] = None,
    cache=False,
    dataset_name='lsm_prod/lsm_300min_100K_unimpute',
    data_dir='/namespace/fitbit-medical-sandboxes/partner/encrypted/chr-ards-electrodes/deid/raw/datasets/msa_1_5/lsm_tfds_datasets'
):
  """Gets and formats the Electrodes dataset.

  Adapted from:
  google3/third_party/py/scenic/dataset_lib/cifar10_dataset.py and
  google3/third_party/py/scenic/dataset_lib/dataset_utils.py.

  Each process loads its own even slice of the 'train' and 'test' TFDS splits,
  applies `preprocess_example`, and exposes the result as numpy iterators that
  are padded and sharded with the `dataset_utils` helpers.

  Args:
    config: Experiment config; only `config.model.patches.size` is read here
      (forwarded to `preprocess_example` as the patch size).
    num_shards: Number of shards each batch is split into (passed to
      `dataset_utils.shard` as `n_devices`).
    batch_size: Training batch size for this process.
    eval_batch_size: Eval batch size for this process; defaults to
      `batch_size`.
    dtype_str: Name of the data type of the signal (e.g. 'float32').
    shuffle_seed: Seed for the training shuffle. Must be None when
      `dataset_service_address` is set.
    rng: Unused.
    shuffle_buffer_size: Training shuffle buffer size; defaults to
      `8 * batch_size`.
    dataset_configs: Unused.
    dataset_service_address: Optional address of a tf.data service used to
      distribute the training dataset.
    cache: Whether to cache the preprocessed train and eval datasets.
    dataset_name: TFDS name of the dataset to load.
    data_dir: TFDS data directory.

  Returns:
    A `dataset_utils.Dataset` built from (train_iter, eval_iter, None,
    meta_data).

  Raises:
    ValueError: If `dataset_service_address` is set while `shuffle_seed` is
      not None.
  """

  del rng
  dtype = getattr(tf, dtype_str)  # data dtype
  p_idx = jax.process_index()  # current process index
  p_cnt = jax.process_count()  # process count (number of processes)
  if eval_batch_size is None:
    eval_batch_size = batch_size  # set eval batch

  # Setup split preprocessing functions.
  train_preprocess_fn = functools.partial(
      preprocess_example, patch_size=config.model.patches.size, dtype=dtype)
  eval_preprocess_fn = functools.partial(
      preprocess_example, patch_size=config.model.patches.size, dtype=dtype)

  # Setup augmentation functions.
  # TODO(girishvn): Set up data augmentations here.

  # Create dataset splits (even splits per worker).
  train_split_range = tfds.even_splits(split='train', n=p_cnt)[p_idx]
  eval_split_range = tfds.even_splits(split='test', n=p_cnt)[p_idx]

  # Load tf dataset.
  train_ds = tfds.load(
      dataset_name, data_dir=data_dir, split=train_split_range,
      shuffle_files=False
  )
  eval_ds = tfds.load(
      dataset_name, data_dir=data_dir, split=eval_split_range,
      shuffle_files=False
  )
  logging.info(
      f'Loaded train and eval split {p_idx}/{p_cnt} from {dataset_name}.'
  )

  # Enable multi threaded workers.
  train_options = tf.data.Options()
  train_options.threading.private_threadpool_size = 48
  train_ds = train_ds.with_options(train_options)

  eval_options = tf.data.Options()
  eval_options.threading.private_threadpool_size = 48
  eval_ds = eval_ds.with_options(eval_options)

  # Applying preprocessing before `ds.cache()` to re-use it.
  train_ds = train_ds.map(
      train_preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
  )
  # The eval pipeline must be derived from `eval_ds` (the 'test' split), not
  # from the already preprocessed training dataset.
  eval_ds = eval_ds.map(
      eval_preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
  )

  # Cache datasets.
  if cache:
    train_ds = train_ds.cache()
    eval_ds = eval_ds.cache()

  train_ds = train_ds.repeat()  # repeat
  # TODO(girishvn): add augmentations
  shuffle_buffer_size = shuffle_buffer_size or (8 * batch_size)
  train_ds = train_ds.shuffle(shuffle_buffer_size, seed=shuffle_seed)  # shuffle
  train_ds = train_ds.batch(batch_size, drop_remainder=True)  # batch
  train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)  # prefetch

  # Batch with the eval batch size so it agrees with the eval padding below.
  eval_ds = eval_ds.batch(eval_batch_size, drop_remainder=True)  # batch
  eval_ds = eval_ds.repeat()  # repeat
  eval_ds = eval_ds.prefetch(tf.data.experimental.AUTOTUNE)

  if dataset_service_address:
    if shuffle_seed is not None:
      raise ValueError('Using dataset service with a random seed causes each '
                       'worker to produce exactly the same data. Add '
                       'config.shuffle_seed = None to your config if you '
                       'want to run with dataset service.')
    train_ds = dataset_utils.distribute(train_ds, dataset_service_address)
    logging.info('Using the tf.data service at %s', dataset_service_address)

  # Other mappings
  maybe_pad_batches_train = functools.partial(
      dataset_utils.maybe_pad_batch, train=True,
      batch_size=batch_size, inputs_key='input_signal')
  maybe_pad_batches_eval = functools.partial(
      dataset_utils.maybe_pad_batch, train=False,
      batch_size=eval_batch_size, inputs_key='input_signal')
  shard_batches = functools.partial(dataset_utils.shard, n_devices=num_shards)

  # Iter dataset
  train_iter = iter(train_ds)
  train_iter = map(dataset_utils.tf_to_numpy, train_iter)
  train_iter = map(maybe_pad_batches_train, train_iter)
  train_iter = map(shard_batches, train_iter)

  eval_iter = iter(eval_ds)
  eval_iter = map(dataset_utils.tf_to_numpy, eval_iter)
  eval_iter = map(maybe_pad_batches_eval, eval_iter)
  eval_iter = map(shard_batches, eval_iter)

  # Save meta data
  info = tfds.builder(dataset_name, data_dir=data_dir, try_gcs=True).info
  # NOTE(review): this shape comes from the TFDS feature spec, so it does not
  # include the crop/pad applied by `preprocess_example` -- confirm that
  # consumers of `input_shape` do not depend on the preprocessed shape.
  input_shape = tuple([-1] + list(info.features['input_signal'].shape))
  meta_data = {
      'input_shape': input_shape,
      'num_train_examples': dataset_utils.get_num_examples(
          dataset=dataset_name, split='train', data_dir=data_dir),
      'num_eval_examples': dataset_utils.get_num_examples(
          dataset=dataset_name, split='test', data_dir=data_dir),
      'input_dtype': getattr(jnp, dtype_str),
      'target_is_onehot': False,
      'num_classes': None,
  }

  # Return dataset structure.
  return dataset_utils.Dataset(train_iter, eval_iter, None, meta_data)


def get_dataset(
    config: ml_collections.ConfigDict,
    data_rng: jnp.ndarray,
    *,
    num_local_shards: Optional[int] = None,
    dataset_service_address: Optional[str] = None,
    dataset_name: Optional[str] = None,
    dataset_configs: Optional[ml_collections.ConfigDict] = None,
    **kwargs: Any,
) -> dataset_utils.Dataset:
  """Selects a dataset builder from the config and builds the dataset.

  Adapted from: google3/third_party/py/scenic/train_lib/train_utils.py

  Args:
    config: Experiment config. Reads `batch_size`, `data_dtype_str`, and
      (optionally) `dataset_name`, `shuffle_seed` and `dataset_configs`.
    data_rng: Random key forwarded to the dataset builder as `rng`.
    num_local_shards: Number of shards per batch; defaults to the local device
      count.
    dataset_service_address: Optional tf.data service address forwarded to the
      builder.
    dataset_name: Name of the dataset; defaults to `config.dataset_name`.
    dataset_configs: Dataset options; defaults to `config.dataset_configs`
      (or an empty dict if that is missing).
    **kwargs: Extra keyword arguments forwarded to the dataset builder.

  Returns:
    The `dataset_utils.Dataset` produced by the selected builder.

  Raises:
    ValueError: If the dataset name is not supported, if the batch size is not
      divisible by the device count, or if a dataset service address is given
      together with a non-None shuffle seed.
  """
  # Report the device topology.
  device_count = jax.device_count()
  process_count = jax.process_count()
  logging.info('device_count: %d', device_count)
  logging.info('num_hosts : %d', process_count)
  logging.info('host_id : %d', jax.process_index())

  # Resolve the builder for the requested dataset.
  supported_builders = {
      'lsm_prod/lsm_300min_100K_unimpute': get_electrodes_dataset,
  }
  dataset_name = dataset_name or config.dataset_name
  if dataset_name not in supported_builders:
    raise ValueError(f'Dataset {dataset_name} is not supported.')
  build_fn = supported_builders[dataset_name]

  # The global batch has to split evenly across all devices.
  batch_size = config.batch_size
  if batch_size % device_count != 0:
    raise ValueError(
        f'Batch size ({batch_size}) must be divisible by the '
        f'number of devices ({device_count})'
    )

  per_host_batch = batch_size // process_count
  per_device_batch = batch_size // device_count
  logging.info('local_batch_size : %d', per_host_batch)
  logging.info('device_batch_size : %d', per_device_batch)

  # A fixed seed combined with the dataset service is rejected.
  shuffle_seed = config.get('shuffle_seed', None)
  if dataset_service_address and shuffle_seed is not None:
    raise ValueError(
        'Using dataset service with a random seed causes each '
        'worker to produce exactly the same data. Add '
        'config.shuffle_seed = None to your config if you want '
        'to run with dataset service.'
    )

  dataset_configs = dataset_configs or config.get('dataset_configs', {})
  num_local_shards = num_local_shards or jax.local_device_count()

  # Build and return the dataset.
  return build_fn(
      config=config,
      num_shards=num_local_shards,
      batch_size=per_host_batch,
      dtype_str=config.data_dtype_str,
      shuffle_seed=shuffle_seed,
      rng=data_rng,
      dataset_configs=dataset_configs,
      dataset_service_address=dataset_service_address,
      **kwargs,
  )


## Training and eval pipeline

In [None]:
#@title Config

# Number of training steps; also used as the LR schedule length in get_config.
NUM_TRAIN_STEPS = 5000

# Model variant / patch H (time steps) / patch W (features)
VARIANT = 'TiShallow/10/5'

# Learning rates; the first entry is the base learning rate used by get_config.
LRS = [1e-3]
TOKEN_MASK_PROB = 'constant_0.8'

# Encoder variants: name -> (hidden_size, mlp_dim, num_heads, num_layers).
_ENCODER_VARIANTS = {
    'Deb': (16, 32, 2, 2),
    'Ti': (192, 768, 3, 12),
    'TiShallow': (192, 768, 3, 4),
    'S': (384, 1536, 6, 12),
    'SShallow': (384, 1536, 6, 4),
    'M': (512, 2048, 8, 12),
    'B': (768, 3072, 12, 12),
    'L': (1024, 4096, 16, 24),
    'H': (1280, 5120, 16, 32),
    'g': (1408, 6144, 16, 40),
    'G': (1664, 8192, 16, 48),
    'e': (1792, 15360, 16, 56),
}
HIDDEN_SIZES = {name: spec[0] for name, spec in _ENCODER_VARIANTS.items()}
MLP_DIMS = {name: spec[1] for name, spec in _ENCODER_VARIANTS.items()}
NUM_HEADS = {name: spec[2] for name, spec in _ENCODER_VARIANTS.items()}
NUM_LAYERS = {name: spec[3] for name, spec in _ENCODER_VARIANTS.items()}

# Decoder variants: name -> (hidden_size, mlp_dim, num_layers, num_heads).
# Only a subset of the encoder variants has a decoder entry.
_DECODER_VARIANTS = {
    'Deb': (16, 32, 2, 2),
    'Ti': (128, 512, 2, 4),
    'TiShallow': (128, 512, 2, 4),
    'S': (256, 1024, 4, 8),
    'B': (512, 2048, 8, 16),
    'L': (512, 2048, 8, 16),
    'H': (512, 2048, 8, 16),
}
DECODER_HIDDEN_SIZES = {
    name: spec[0] for name, spec in _DECODER_VARIANTS.items()
}
DECODER_MLP_DIMS = {name: spec[1] for name, spec in _DECODER_VARIANTS.items()}
DECODER_NUM_LAYERS = {
    name: spec[2] for name, spec in _DECODER_VARIANTS.items()
}
DECODER_NUM_HEADS = {
    name: spec[3] for name, spec in _DECODER_VARIANTS.items()
}


def get_config(runlocal=''):
  """Returns the ViT MAE experiment configuration for the electrodes dataset.

  The model size and patch shape come from the module-level `VARIANT` string,
  formatted as 'version/patch_h/patch_w' or 'version/patch' (same value used
  for both patch dimensions).

  Args:
    runlocal: Any truthy value selects the debug setup: the 'Deb' model
      variant, a batch size of 8, and `count_flops` disabled.

  Returns:
    An `ml_collections.ConfigDict` describing the dataset, model, loss,
    training schedule, optimizer, fewshot evaluation and logging.

  Raises:
    ValueError: If `VARIANT` does not have two or three '/'-separated fields.
  """
  runlocal = bool(runlocal)

  config = ml_collections.ConfigDict()
  config.experiment_name = 'electrodes-mae-vit-tiny'

  # Dataset.
  config.dataset_name = 'lsm_prod/lsm_300min_100K_unimpute'
  config.data_dtype_str = 'float32'
  data_cfg = ml_collections.ConfigDict()
  data_cfg.dataset = 'lsm_prod/lsm_300min_100K_unimpute'
  # data_cfg.num_classes = NUM_CLASSES
  data_cfg.train_split = 'train'
  data_cfg.val_split = 'test'
  # NOTE: Can inject augmentations / preprocessing of dataset here
  # using config.dataset_configs.pp_train and config.dataset_configs.pp_eval.
  # Refer to scenic.dataset_lib.big_transfer.bit for how this is done on
  # Cifar10.
  data_cfg.prefetch_to_device = 2
  # Shuffle_buffer_size is per host, so small-ish is ok.
  data_cfg.shuffle_buffer_size = 250_000
  config.dataset_configs = data_cfg

  # Model: parse 'version/patch_h/patch_w' (or 'version/patch').
  variant_fields = VARIANT.split('/')
  if len(variant_fields) == 3:
    version, patch_h, patch_w = variant_fields
  elif len(variant_fields) == 2:
    version, patch_h = variant_fields
    patch_w = patch_h  # single value: same size along both patch dimensions
  else:
    raise ValueError(f'Invalid model variant: {VARIANT}')
  if runlocal:
    version = 'Deb'

  config.model_name = 'vit_mae'
  model_cfg = ml_collections.ConfigDict()
  # encoder
  model_cfg.hidden_size = HIDDEN_SIZES[version]
  model_cfg.patches = ml_collections.ConfigDict()
  model_cfg.patches.size = [int(patch_h), int(patch_w)]
  model_cfg.num_heads = NUM_HEADS[version]
  model_cfg.mlp_dim = MLP_DIMS[version]
  model_cfg.num_layers = NUM_LAYERS[version]
  model_cfg.dropout_rate = 0.
  model_cfg.classifier = 'none'  # Has to be "none" for the autoencoder
  model_cfg.representation_size = None
  model_cfg.positional_embedding = 'sinusoidal_2d'
  model_cfg.positional_embedding_decoder = 'sinusoidal_2d'
  # decoder
  decoder_cfg = ml_collections.ConfigDict()
  decoder_cfg.hidden_size = DECODER_HIDDEN_SIZES[version]
  decoder_cfg.mlp_dim = DECODER_MLP_DIMS[version]
  decoder_cfg.num_layers = DECODER_NUM_LAYERS[version]
  decoder_cfg.num_heads = DECODER_NUM_HEADS[version]
  decoder_cfg.dropout_rate = 0.
  decoder_cfg.attention_dropout_rate = 0.
  model_cfg.decoder_config = decoder_cfg
  config.model = model_cfg

  # Masked reconstruction loss.
  loss_cfg = ml_collections.ConfigDict()
  loss_cfg.targets_type = 'rgb'
  loss_cfg.token_mask_probability = TOKEN_MASK_PROB
  loss_cfg.loss_only_masked_tokens = True
  loss_cfg.loss_type = 'squared'  # 'squared' or 'absolute'
  config.masked_feature_loss = loss_cfg

  # Training.
  config.trainer_name = 'multimask_trainer'
  config.batch_size = 8 if runlocal else 1024
  config.num_training_steps = NUM_TRAIN_STEPS
  config.log_eval_steps = 100
  config.log_summary_steps = 100
  config.rng_seed = 42

  # Learning-rate schedule, applied to every parameter (regex '(.*)').
  lr_cfg = ml_collections.ConfigDict()
  lr_cfg.learning_rate_schedule = 'compound'
  lr_cfg.factors = 'constant * cosine_decay * linear_warmup'
  lr_cfg.total_steps = NUM_TRAIN_STEPS
  lr_cfg.steps_per_cycle = lr_cfg.total_steps
  lr_cfg.warmup_steps = 2000
  lr_cfg.base_learning_rate = LRS[0]
  sched = ml_collections.ConfigDict()
  sched.re = '(.*)'
  sched.lr_configs = lr_cfg
  config.schedule = ml_collections.ConfigDict({'all': sched})

  # *Single* optimizer.
  optim = ml_collections.ConfigDict()
  optim.optax_name = 'scale_by_adam'
  optim.optax_configs = ml_collections.ConfigDict(
      {  # Optimizer settings.
          'b1': 0.9,
          'b2': 0.999,
      })
  # NOTE(review): this sets a top-level `config.optax`, not `optim.optax` --
  # confirm which of the two the optimizer factory actually reads.
  config.optax = {'mu_dtype': 'bfloat16'}
  optim.max_grad_norm = 1.0
  optim.weight_decay = 1e-4
  optim.weight_decay_decouple = True
  config.optimizer = optim

  # Fewshot.
  # TODO(girishvn): This needs to be adapted to electrode dataset
  config.fewshot = common_fewshot.get_config(
      batch_size=config.batch_size
  )
  config.fewshot.datasets = {}
  config.fewshot.walk_first = ()
  config.fewshot.representation_layer = 'pre_logits'
  config.fewshot.log_eval_steps = 1000

  # Logging.
  config.write_summary = True
  config.xprof = True  # Profile using xprof.
  config.checkpoint = True  # Do checkpointing.
  config.checkpoint_steps = 1000
  config.debug_train = False  # Debug mode during training.
  config.debug_eval = False  # Debug mode during eval.

  # BEGIN GOOGLE-INTERNAL
  if runlocal:
    # Current implementation fails with UPTC.
    config.count_flops = False
  # END GOOGLE-INTERNAL

  return config


# BEGIN GOOGLE-INTERNAL
def get_hyper(hyper):
  """Defines the hyper-parameter sweep used for grid search.

  Args:
    hyper: Object providing `sweep` and `product`.

  Returns:
    The product of a single sweep: the base learning rate over `LRS`.
  """
  lr_sweep = hyper.sweep(
      'config.schedule.all.lr_configs.base_learning_rate', LRS
  )
  return hyper.product([lr_sweep])
# Build the configuration with a truthy `runlocal`, i.e. the debug setup
# ('Deb' variant, batch size 8).
config = get_config(True)
# Bare expression; presumably left here so the notebook cell displays the
# trainer name as its output.
config.trainer_name

In [None]:
#@title Initialize training states

# Adapted from google3/third_party/py/scenic/projects/multimask/trainer.py.

# NOTE: the model class is selected by this hard-coded name, not by
# `config.model_name`.
model_cls = get_model_cls('vit_mae_single_channel')
# Root RNG; one key is split off for the dataset, the rest is split again
# below for initialization and training.
rng = jax.random.PRNGKey(config.rng_seed)
data_rng, rng = jax.random.split(rng)
dataset = get_dataset(
    config, data_rng
)

# Only process 0 writes work-unit notes (see `write_note`).
lead_host = jax.process_index() == 0
# Build the loss_fn, metrics, and flax_model.
model = model_cls(config, dataset.meta_data)

# Initialize model.
rng, params_init_rng, dropout_init_rng = jax.random.split(rng, num=3)
init_rngs = {'params': params_init_rng, 'dropout': dropout_init_rng}
# One real training batch is consumed here only to obtain the input spec.
init_batch = next(dataset.train_iter)
(params, model_state, num_trainable_params, gflops) = (
    train_utils.initialize_model(
        model_def=model.flax_model,
        # `shape[1:]` drops the leading axis of the batch -- presumably the
        # device axis added by sharding; TODO confirm.
        input_spec=[
            (init_batch['input_signal'].shape[1:],
             init_batch['input_signal'].dtype)
        ],
        config=config,
        rngs=init_rngs,
        train=True,  # so that masking and decoding in MAE are initialized
    )
)

# Create LR schedules and optimizer.
schedule_fns = scenic_optax.make_schedule(config.get('schedule'))
tx, _ = scenic_optax.make(config.optimizer, schedule_fns, params)
opt_state = tx.init(params)

rng, train_rng = jax.random.split(rng)

# Create chrono class to track and store training statistics and metadata:
chrono = train_utils.Chrono()

# Fresh training state starting at step 0 (no checkpoint is restored here).
train_state = train_utils.TrainState(
    global_step=0,
    opt_state=opt_state,
    tx=tx,
    params=params,
    model_state=model_state,
    rng=train_rng,
    metadata={'chrono': chrono.save()},
)
start_step = train_state.global_step
# Round-trip the chrono metadata through the train state, then clear the
# metadata before the state is replicated.
chrono.load(train_state.metadata['chrono'])
train_state = train_state.replace(metadata={})

# Replicate the optimizer, state, and rng.
train_state = jax_utils.replicate(train_state)
del params  # Do not keep a copy of the initial params.

# Calculate the total number of training steps.
# TODO(adosovitskiy): get rid of epochs?
total_steps, steps_per_epoch = train_utils.get_num_training_steps(
    config, dataset.meta_data
)

# Parallel train step; everything except (train_state, batch) is bound here.
train_step_pmapped = jax.pmap(
    functools.partial(
        trainer.train_step,
        flax_model=model.flax_model,
        loss_fn=model.loss_function,
        # Map each schedule name to its learning-rate function.
        lr_fns={name: lr_fn for _, name, (lr_fn, _) in schedule_fns},
        metrics_fn=model.get_metrics_fn('train'),
        config=config,
        debug=config.debug_train,
    ),
    axis_name='batch',
    # We can donate both buffers of train_state and train_batch.
    donate_argnums=(0, 1),
)
# Parallel eval step; only the batch buffer is donated because the train
# state is reused afterwards.
eval_step_pmapped = jax.pmap(
    functools.partial(
        trainer.eval_step,
        flax_model=model.flax_model,
        metrics_fn=model.get_metrics_fn('validation'),
        config=config,
        debug=config.debug_eval,
    ),
    axis_name='batch',
    # We can donate the eval_batch's buffer.
    donate_argnums=(1,),
)

# Optional few-shot evaluator; it is constructed here but not called in the
# visible part of this notebook.
if 'fewshot' in config:
  representation_fn_partial = functools.partial(
      trainer.representation_fn,
      flax_model=model.flax_model,
      representation_layer=config.fewshot.representation_layer,
  )

  fewshotter = fewshot_utils.FewShotEvaluator(
      representation_fn_partial, config.fewshot
  )

# Logging / checkpoint cadence; the latter two fall back to `log_eval_steps`.
log_eval_steps = config.get('log_eval_steps')
if not log_eval_steps:
  raise ValueError("'log_eval_steps' should be specified in the config.")
checkpoint_steps = config.get('checkpoint_steps') or log_eval_steps
log_summary_steps = config.get('log_summary_steps') or log_eval_steps

# Accumulators filled by the training loop.
train_metrics, extra_training_logs = [], []
train_summary, eval_summary = None, None

chrono.inform(start_step, total_steps, config.batch_size, steps_per_epoch)
logging.info('Starting training loop at step %d.', start_step + 1)


def write_note(note):
  """Sets `note` as the work-unit note; does nothing unless on the lead host.

  Args:
    note: The note passed to `platform.work_unit().set_notes`.
  """
  if not lead_host:
    return
  platform.work_unit().set_notes(note)


# Callbacks invoked as `h(step)` after every training step by the training
# loop; empty by default.
hooks = []


def evaluate(
    train_state: train_utils.TrainState,
    step: int,
    valid_iter: Iterator[Batch],
    num_valid_ex: int,
) -> Dict[str, Any]:
  """Runs the pmapped eval step over one or more validation iterators.

  Args:
    train_state: Replicated train state passed to `eval_step_pmapped`.
    step: Current training step (not used by this function).
    valid_iter: Either a single iterator of eval batches, or a dict mapping a
      split name to such an iterator. A single iterator is stored under the
      name 'valid'.
    num_valid_ex: Number of eval examples; a dict keyed like `valid_iter` when
      `valid_iter` is a dict.

  Returns:
    A dict mapping each split name to the list of per-step eval metrics
    (after `train_utils.unreplicate_and_get`).
  """
  if isinstance(valid_iter, dict):
    named_iters, named_counts = valid_iter, num_valid_ex
  else:  # Only on validation set.
    named_iters = {'valid': valid_iter}
    named_counts = {'valid': num_valid_ex}

  summary = {}
  for split_name, batches in named_iters.items():
    n_examples = named_counts[split_name]
    eval_batch_size = config.get('eval_batch_size', config.batch_size)
    # Ceil rounding such that we include the last incomplete batch.
    full_pass_steps = int(np.ceil(n_examples / eval_batch_size))
    n_steps = config.get('steps_per_eval') or full_pass_steps

    collected = []
    for _ in range(n_steps):
      step_metrics, _ = eval_step_pmapped(train_state, next(batches))
      collected.append(train_utils.unreplicate_and_get(step_metrics))
      # Stored inside the loop: a split with zero eval steps gets no entry.
      summary[split_name] = collected
  return summary


def process_valid_summary(eval_summary):
  """Averages the four reconstruction metrics over all 'valid' eval steps.

  For every entry of `eval_summary['valid']`, element 0 of each metric is
  converted to a float; the per-step values are then averaged per metric.

  Args:
    eval_summary: dict with a 'valid' list of per-step metric dicts, each
      holding 'mean_absolute_error_all', 'mean_absolute_error_masked',
      'mean_squared_error_all' and 'mean_squared_error_masked'.

  Returns:
    A tuple (mae_all, mae_masked, mse_all, mse_masked) of averages.

  Raises:
    ZeroDivisionError: If `eval_summary['valid']` is empty.
  """
  metric_keys = (
      'mean_absolute_error_all',
      'mean_absolute_error_masked',
      'mean_squared_error_all',
      'mean_squared_error_masked',
  )
  # One row of four floats per eval step, in `metric_keys` order.
  rows = [
      [float(step_metrics[key][0]) for key in metric_keys]
      for step_metrics in eval_summary['valid']
  ]
  n_steps = len(rows)
  return tuple(
      sum(row[col] for row in rows) / n_steps
      for col in range(len(metric_keys))
  )


def smooth_data(data, smoothing_factor=0.9):
  """Applies exponential smoothing to a sequence of values.

  The first value is kept as is; every later output is
  `smoothing_factor * previous_output + (1 - smoothing_factor) * value`.

  Args:
    data: Iterable of numeric values.
    smoothing_factor: Weight given to the previous smoothed value.

  Returns:
    A list of smoothed values with the same length as `data`.
  """
  values = iter(data)
  try:
    running = [next(values)]
  except StopIteration:
    return []

  for value in values:
    running.append(
        smoothing_factor * running[-1] + (1 - smoothing_factor) * value
    )
  return running


def plot_steps(data, mode='train', smoothing_factor=None):
  """Plots the four tracked metrics against the training step.

  Args:
    data: dict mapping a step to a 4-tuple
      (mae_all, mae_masked_all, mse_all, mse_masked_all).
    mode: Label appended to each panel title, e.g. 'train' or 'validation'.
    smoothing_factor: If not None, each metric is passed through
      `smooth_data` with this factor before plotting.
  """
  # (DataFrame column, axis label) for each of the four panels, in order.
  panels = (
      ('mae_all', 'MAE All'),
      ('mae_masked_all', 'MAE Masked All'),
      ('mse_all', 'MSE All'),
      ('mse_masked_all', 'MSE Masked All'),
  )
  metric_columns = [column for column, _ in panels]

  # Steps become the 'steps' column; one row per recorded step.
  frame = pd.DataFrame.from_dict(
      data, orient='index', columns=metric_columns
  )
  frame = frame.reset_index().rename(columns={'index': 'steps'})

  # Apply smoothing if specified
  if smoothing_factor is not None:
    for column in metric_columns:
      frame[column] = smooth_data(frame[column].tolist(), smoothing_factor)

  # One panel per metric, side by side.
  _, axes = plt.subplots(1, 4, figsize=(20, 5))
  for axis, (column, label) in zip(axes, panels):
    sns.lineplot(ax=axis, x='steps', y=column, data=frame)
    axis.set_title(f'{label} ({mode})')
    axis.set_xlabel('Steps')
    axis.set_ylabel(label)

  plt.tight_layout()
  plt.show()

In [None]:
#@title Run training

# Adapted from google3/third_party/py/scenic/projects/multimask/trainer.py.
# and google3/third_party/py/scenic/projects/multimask/main.py

# Per-step metric records: step -> (mae_all, mae_masked, mse_all, mse_masked).
train_losses = {}
validation_losses = {}
# Wall-clock start; the per-step time printed below is an average since this
# point and therefore includes evaluation time.
t_start = time.time()

flax.config.update('flax_use_orbax_checkpointing', False)
for step in tqdm.tqdm(range(start_step + 1, total_steps + 1)):
  with jax.profiler.StepTraceAnnotation('train', step_num=step):
    train_batch = next(dataset.train_iter)
    train_state, t_metrics, t_logs = train_step_pmapped(
        train_state, train_batch
    )
    # This will accumulate metrics in TPU memory up to the point that we log
    # them. This is no problem for small metrics but may be a problem for
    # large (e.g. segmentation) metrics. An alternative is to set
    # `log_summary_steps` to a small number, or to use
    # `train_utils.unreplicate_and_get` here instead of right before writing
    # summaries, but that means in each step, we have data transfer between
    # tpu and host, which might slow down the training.
    train_metrics.append(t_metrics)
    # Additional training logs: learning rate:
    t_logs = jax.tree_util.tree_map(jax_utils.unreplicate, t_logs)
    extra_training_logs.append(t_logs)
    # Read the metrics as Python floats every step (rounded to 2 decimals).
    # NOTE(review): `[0][0]` presumably selects the metric value (rather than
    # its normalizer) on the first device -- confirm against the metrics_fn.
    mae_all = round(float(t_metrics['mean_absolute_error_all'][0][0]), 2)
    mae_masked = round(float(t_metrics['mean_absolute_error_masked'][0][0]), 2)
    mse_all = round(float(t_metrics['mean_squared_error_all'][0][0]), 2)
    mse_masked = round(float(t_metrics['mean_squared_error_masked'][0][0]), 2)
    # Print a progress line every 100 steps.
    if step % 100 == 0:
      t_current = time.time()
      t_step = (t_current - t_start) / step
      print(
          'step',
          step,
          'mae_all:',
          mae_all,
          'mae_masked:',
          mae_masked,
          'mse_all:',
          mse_all,
          'mse_masked:',
          mse_masked,
          'time per step:',
          f'{t_step}s\n',
          sep=' ',
      )
    train_losses[step] = (mae_all, mae_masked, mse_all, mse_masked)
  # Quick indication that training is happening.
  logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
  for h in hooks:
    h(step)
  ################### EVALUATION #######################
  # Evaluate on steps 1, log_eval_steps + 1, ... and on the final step.
  if (step % log_eval_steps == 1) or (step == total_steps):
    # chrono.pause(wait_for=(train_state.params))
    train_state = train_utils.sync_model_state_across_replicas(train_state)
    eval_summary = evaluate(
          train_state,
          step,
          dataset.valid_iter,
          dataset.meta_data['num_eval_examples'],
      )
    # NOTE: this rebinds `mae_all` / `mse_all`, which above hold the training
    # values; they are overwritten again at the next training step.
    mae_all, mae_masked_all, mse_all, mse_masked_all = process_valid_summary(eval_summary)
    validation_losses[step] = (mae_all, mae_masked_all, mse_all, mse_masked_all)

# Time metrics
t_end = time.time()
t_total = (t_end - t_start)
print(f'total training and eval time: {t_total}\n')

# Notebook magic: render figures inline, then plot both metric histories
# (training curves are smoothed, validation curves are not).
%matplotlib inline
plot_steps(validation_losses, 'validation')
plot_steps(train_losses, 'train', smoothing_factor=0.9)