In [None]:
# @title Imports

import io
import functools
from typing import Any, Callable, Dict, Iterator, Tuple, Optional, Type, Union
import time
from collections import Counter
import copy

from absl import logging
from clu import metric_writers
from clu import periodic_actions
from clu import platform

import flax
from flax import jax_utils
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.profiler

import pandas as pd
import ml_collections
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
%matplotlib inline

from colabtools import adhoc_import
with adhoc_import.Google3():
  from scenic.dataset_lib import dataset_utils
  from scenic.google.xm import xm_utils
  from scenic.model_lib.base_models import base_model
  from scenic.model_lib.base_models import model_utils
  from scenic.model_lib.layers import nn_ops
  from scenic.model_lib.layers import nn_layers
  from scenic.projects.baselines import vit
  from scenic.train_lib import optax as scenic_optax
  from scenic.train_lib import pretrain_utils
  from scenic.train_lib import train_utils

  from scenic.projects.multimask.models import model_utils as mm_model_utils

  from google3.experimental.largesensormodels.scenic.datasets import dataset_constants
  from google3.experimental.largesensormodels.scenic.datasets import lsm_activity_subset_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_mood_vs_activity_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_tiny_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_combined_pretrain_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_fewshot_mood_vs_activity_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_fewshot_remapped_activity_dataset

  from google3.experimental.largesensormodels.scenic.models import lsm_vit as lsm_vit_mae
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_constants
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_utils as lsm_model_utils
  from google3.experimental.largesensormodels.scenic.trainers import lsm_mae_trainer

  from google3.pyglib import gfile


# Type aliases for annotations.
# NOTE(review): none of these is referenced by the cells visible in this part
# of the notebook (apart from `Batch` inside `LossFn`) — confirm before pruning.

# String-keyed dict of arrays.
Batch = Dict[str, jnp.ndarray]
# Callable taking (array, array, dict of arrays) and returning a dict that maps
# a name to a (float, int) pair.
MetricFn = Callable[
    [jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray]],
    Dict[str, Tuple[float, int]],
]
# Callable taking (array, Batch, optional array, array) and returning a float.
LossFn = Callable[
    [jnp.ndarray, Batch, Optional[jnp.ndarray], jnp.ndarray], float
]
# Dict of named array -> array callables (presumably step -> learning rate).
LrFns = Dict[str, Callable[[jnp.ndarray], jnp.ndarray]]
# Patch size given as either two or three ints.
Patch = Union[Tuple[int, int], Tuple[int, int, int]]


In [None]:
# @title Sample Config (Base ViT MAE)

r"""A config to train a Base ViT MAE on LSM dataset.

Forked from google3/third_party/py/scenic/projects/multimask/configs/mae_cifar10_tiny.py
"""

from typing import Optional
import ml_collections
from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_constants


# To set constants.
# Module-level settings consumed by get_config() below. For the list-valued
# settings (TRAIN_DATA_SIZES, USE_TRAIN_AUGMENTATIONS, LRS, WEIGHT_DECAYS) only
# element [0] is read there.
# 1) Dataset variables.
DATASET_NAME = 'lsm_300min_pretraining_165K_n10'
CACHE_DATASET = True
TRAIN_DATA_SIZES = [1321235]
USE_DATETIME_FEATURES = False
USE_TRAIN_AUGMENTATIONS = [True]
TRAIN_AUGMENTATIONS = ['stretch', 'flip', 'noise']
SHUFFLE_SEED = 42

# 2) Training / eval variables.
BATCH_SIZE = 8
NUM_TRAIN_STEPS = 50000
LRS = [5e-3]
WEIGHT_DECAYS = [1e-4]

# 3) Logging variables.
# The same value is used for both log_eval_steps and log_summary_steps.
LOG_EVAL_SUMMARY_STEPS = 500  # STEPS_PER_EPOCH
# NOTE(review): the trailing note says "LOG_EVAL_SUMMARY_STEPS * 5", but the
# value is a literal 100 (500 * 5 would be 2500) — confirm which is intended.
LOG_CHECKPOINT_STEPS = 100  # LOG_EVAL_SUMMARY_STEPS * 5
# Number of checkpoint intervals in the whole run (50000 / 100 = 500); used as
# max_checkpoints_to_keep.
MAX_NUM_CHECKPOINTS = int(NUM_TRAIN_STEPS / LOG_CHECKPOINT_STEPS)

# Model variant / patch H (time steps) / patch W (features)
# get_config() also accepts a two-field form, 'variant/patch', in which case
# the single patch value is used for both H and W.
VARIANT = 'B/10/5'
TOKEN_MASK_PROB = 'constant_0.8'
LOSS_ONLY_MASKED_TOKENS = True

# Downstream Tasks.
# Imputation and forecast eval
# The same list is assigned to both config.forecast.horizons and
# config.imputation.horizons.
RECONSTRUCTION_HORIZONS = [0.1, 0.2, 0.4]

# Linear probe eval.
# NOTE(review): these two settings are not read by get_config() in this cell.
LINEAR_PROBE_USE_TRAIN_AUGMENTATIONS = False
LINEAR_PROBE_TRAIN_AUGMENTATIONS = ['noise']


def get_config(runlocal=''):
  """Returns the ViT MAE experiment configuration.

  The configuration is assembled from the module-level constants defined above
  this function (DATASET_NAME, VARIANT, LRS, ...).

  Args:
    runlocal: Only its truthiness is used. When truthy, the local-debug
      settings are selected: the eval split is limited to 'test[:64]', the
      model version is forced to 'Deb', and `count_flops` is set to False.

  Returns:
    An `ml_collections.ConfigDict` describing dataset, model, masked-feature
    loss, LR schedule, optimizer, downstream reconstruction tasks and logging.

  Raises:
    ValueError: If `VARIANT` does not consist of two or three '/'-separated
      fields.
  """

  runlocal = bool(runlocal)

  # Experiment.
  config = ml_collections.ConfigDict()
  config.experiment_name = f'electrodes-mae-{DATASET_NAME}'
  config.dataset_name = f'lsm_prod/{DATASET_NAME}'
  config.shuffle_seed = SHUFFLE_SEED

  # Dataset.
  config.data_dtype_str = 'float32'
  config.dataset_configs = ml_collections.ConfigDict()
  config.dataset_configs.dataset = f'lsm_prod/{DATASET_NAME}'
  config.dataset_configs.num_classes = None
  config.dataset_configs.train_split = 'train'  # train data split
  config.dataset_configs.train_num_samples = TRAIN_DATA_SIZES[0]  # train sample
  # eval data split - note: this split is used for validation and test.
  config.dataset_configs.eval_split = 'test[:64]' if runlocal else 'test'
  config.dataset_configs.cache_dataset = CACHE_DATASET
  config.dataset_configs.prefetch_to_device = 2
  # Shuffle_buffer_size is per host, so small-ish is ok.
  config.dataset_configs.shuffle_buffer_size = 250_000

  # Model.
  # VARIANT is '<version>/<patch_h>/<patch_w>' or '<version>/<patch>'; in the
  # two-field form the single patch value is used for both height and width.
  variant_fields = VARIANT.split('/')
  if len(variant_fields) == 3:
    version, patch_h, patch_w = variant_fields
  elif len(variant_fields) == 2:
    version, patch_h = variant_fields
    patch_w = patch_h
  else:
    raise ValueError(f'Invalid model variant: {VARIANT}')

  # Local runs always use the 'Deb' entry of the model_constants tables.
  version = 'Deb' if runlocal else version
  config.model_name = 'lsm_vit_mae'
  config.model = ml_collections.ConfigDict()
  # encoder
  config.model.hidden_size = model_constants.HIDDEN_SIZES[version]
  config.model.patches = ml_collections.ConfigDict()
  config.model.patches.size = (int(patch_h), int(patch_w))
  config.model.num_heads = model_constants.NUM_HEADS[version]
  config.model.mlp_dim = model_constants.MLP_DIMS[version]
  config.model.num_layers = model_constants.NUM_LAYERS[version]
  config.model.dropout_rate = 0.0
  config.model.classifier = 'none'  # Has to be "none" for the autoencoder
  config.model.representation_size = None
  config.model.positional_embedding = 'sinusoidal_2d'
  config.model.positional_embedding_decoder = 'sinusoidal_2d'
  # decoder
  config.model.decoder_config = ml_collections.ConfigDict()
  config.model.decoder_config.hidden_size = (
      model_constants.DECODER_HIDDEN_SIZES[version]
  )
  config.model.decoder_config.mlp_dim = model_constants.DECODER_MLP_DIMS[
      version
  ]
  config.model.decoder_config.num_layers = model_constants.DECODER_NUM_LAYERS[
      version
  ]
  config.model.decoder_config.num_heads = model_constants.DECODER_NUM_HEADS[
      version
  ]
  config.model.decoder_config.dropout_rate = 0.0
  config.model.decoder_config.attention_dropout_rate = 0.0

  config.masked_feature_loss = ml_collections.ConfigDict()
  config.masked_feature_loss.targets_type = 'rgb'
  config.masked_feature_loss.token_mask_probability = TOKEN_MASK_PROB
  config.masked_feature_loss.loss_only_masked_tokens = LOSS_ONLY_MASKED_TOKENS
  config.masked_feature_loss.loss_type = 'squared'  # 'squared' or 'absolute'

  # Datetime features.
  config.use_datetime_features = USE_DATETIME_FEATURES

  # Training.
  config.trainer_name = 'lsm_mae_trainer'
  config.batch_size = 8 if runlocal else BATCH_SIZE
  config.num_training_steps = NUM_TRAIN_STEPS
  config.log_eval_steps = LOG_EVAL_SUMMARY_STEPS
  config.log_summary_steps = LOG_EVAL_SUMMARY_STEPS
  config.rng_seed = 42
  config.use_train_augmentations = USE_TRAIN_AUGMENTATIONS[0]
  config.train_augmentations = TRAIN_AUGMENTATIONS
  # A single schedule, registered under the key 'all' with the catch-all
  # regex '(.*)'.
  sched = ml_collections.ConfigDict()
  sched.re = '(.*)'
  sched.lr_configs = ml_collections.ConfigDict()
  sched.lr_configs.learning_rate_schedule = 'compound'
  sched.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
  sched.lr_configs.total_steps = NUM_TRAIN_STEPS
  sched.lr_configs.steps_per_cycle = sched.lr_configs.total_steps
  sched.lr_configs.warmup_steps = int(NUM_TRAIN_STEPS * 0.05)
  sched.lr_configs.base_learning_rate = LRS[0]
  config.schedule = ml_collections.ConfigDict({'all': sched})

  # *Single* optimizer.
  optim = ml_collections.ConfigDict()
  optim.optax_name = 'scale_by_adam'
  # optim.optax = dict(mu_dtype='bfloat16')
  optim.optax_configs = ml_collections.ConfigDict({  # Optimizer settings.
      'b1': 0.9,
      'b2': 0.95,
  })
  # NOTE(review): this is set on the top-level `config`, not on `optim`, while
  # the commented-out line above targets `optim` — confirm where the
  # optimizer factory expects `mu_dtype`. Left unchanged to preserve behavior.
  config.optax = dict(mu_dtype='bfloat16')
  optim.max_grad_norm = 1.0

  optim.weight_decay = WEIGHT_DECAYS[0]
  optim.weight_decay_decouple = True
  config.optimizer = optim

  # Downstream Tasks.

  # 2) Reconstruction Eval Tasks (Forecast and Imputation).
  config.forecast = ml_collections.ConfigDict()
  config.forecast.horizons = RECONSTRUCTION_HORIZONS
  config.imputation = ml_collections.ConfigDict()
  config.imputation.horizons = RECONSTRUCTION_HORIZONS

  # Logging.
  config.write_summary = True
  config.xprof = True  # Profile using xprof.
  config.checkpoint = True  # Do checkpointing.
  config.checkpoint_steps = LOG_CHECKPOINT_STEPS
  config.debug_train = False  # Debug mode during training.
  config.debug_eval = False  # Debug mode during eval.
  config.max_checkpoints_to_keep = MAX_NUM_CHECKPOINTS
  # BEGIN GOOGLE-INTERNAL
  if runlocal:
    # Current implementation fails with UPTC.
    config.count_flops = False
  # END GOOGLE-INTERNAL

  return config


In [None]:
# @title Parameter Equivalence Function

def compare_params(params1, params2, space=0):
  """Prints the differences between two (nested) parameter dicts.

  Iterates over the keys of `params1` and looks each one up in `params2`.
  Nothing is printed for leaves that have the same shape and equal values;
  the name of every sub-dict that is descended into is printed as a header.

  Args:
    params1: Nested dict whose leaves expose `.shape` (e.g. arrays). Its keys
      drive the traversal.
    params2: Nested dict to compare against. Keys that exist only in `params2`
      are not visited.
    space: Current nesting depth; controls the indentation of the output.
  """
  indent = ' ' * space * 2
  for key, value in params1.items():

    if isinstance(value, dict):
      if key in params2 and isinstance(params2[key], dict):
        # Blank line between top-level sub-trees for readability.
        if space == 0:
          print()
        print(f'{indent}{key}:')
        compare_params(value, params2[key], space + 1)
      else:
        # Report instead of silently skipping, so that a partially restored
        # tree is not mistaken for an identical one.
        print(f'{indent}{key}: MISSING in params2 (or not a sub-dict)')

    elif key not in params2:
      print(f'{indent}{key}: MISSING in params2')

    else:
      other = params2[key]
      if value.shape != other.shape:
        print(f'{indent}{key}: params1 shape, {value.shape}')
        print(f'{indent}{key}: params2 shape, {other.shape}')
      elif not jnp.array_equal(value, other):
        print(f'{indent}{key}: NOT EQUAL')


In [None]:
# @title Model Init Function

def init_model(config, input_shape, dtype=jnp.float32):
  """Builds the single-channel ViT MAE model and a fresh train state.

  Args:
    config: Experiment config. `rng_seed`, `schedule` and `optimizer` are read
      here; the whole config is also handed to the model constructor and to
      `train_utils.initialize_model`.
    input_shape: Shape used for the single entry of the model's input spec.
    dtype: Dtype used for the single entry of the model's input spec.

  Returns:
    A `(model, train_state)` tuple, where `train_state` is at global step 0
    and holds the freshly initialized params and optimizer state.
  """
  model = lsm_vit_mae.ViTMAESingleChannelModel(config, {})

  # One key each for parameter init and dropout; `rng` is carried forward.
  rng = jax.random.PRNGKey(config.rng_seed)
  rng, params_key, dropout_key = jax.random.split(rng, num=3)

  # train=True so that masking and decoding in MAE are initialized.
  # NOTE: Do not delete init batch, as it may be used to init downstream models.
  params, model_state, _, _ = train_utils.initialize_model(
      model_def=model.flax_model,
      input_spec=[(input_shape, dtype)],
      config=config,
      rngs={'params': params_key, 'dropout': dropout_key},
      train=True,
  )

  # LR schedules feed the optimizer factory.
  lr_schedules = scenic_optax.make_schedule(config.get('schedule'))
  tx, _ = scenic_optax.make(config.optimizer, lr_schedules, params)

  # Key stored in the train state.
  _, train_rng = jax.random.split(rng)

  # New / empty train state; the chrono entry stores training statistics and
  # metadata.
  return model, train_utils.TrainState(
      global_step=0,
      opt_state=tx.init(params),
      tx=tx,
      params=params,
      model_state=model_state,
      rng=train_rng,
      metadata={'chrono': train_utils.Chrono().save()},
  )


In [None]:
# @title XM and Config Setup

# XM job to load ckpt from
# Both values are passed to xm_utils.get_info_from_xmanager(xid, wid) in the
# cells below (presumably experiment id and work-unit id — TODO confirm).
xid, wid = (134755220, 5)

# Get config
# Called with the default runlocal='', i.e. the non-local settings.
config = get_config()


## Checkpoint Restoration

We explore three methods:

1. The currently used method (`model.init_from_train_state`). It does not fully restore the checkpoint state: the `output_projection` parameters are ignored.
2. The loading method currently implemented for `generative_eval`.
This method works, but could be better formalized.
3. The final method, which will be implemented in the MAE trainer.

In [None]:
# @title Method 1

# Init model
# Fresh model and train state; the input spec shape is [8, 300, 30, 1].
model, train_state = init_model(config, input_shape=[8, 300, 30, 1])

# Init from XM (XID, WID)
# Returns the config of the restored run and the path of its checkpoint.
restored_model_cfg, init_checkpoint_path = xm_utils.get_info_from_xmanager(
    xid, wid
)
restored_train_state = pretrain_utils.restore_pretrained_checkpoint(
    init_checkpoint_path, train_state, assert_exist=True
)
# Keep an independent copy of the restored state as the reference for the
# comparison below (presumably in case the update step mutates its input).
cp_restored_train_state = copy.deepcopy(restored_train_state)

# Update train state
# Method under test: the model's own init_from_train_state.
train_state = model.init_from_train_state(  # pytype: disable=attribute-error
    train_state, restored_train_state, restored_model_cfg
)

# Check Same-ness
original_params = cp_restored_train_state.params
restored_params = train_state.params

# Note the naming: `raw_params` holds the checkpoint's params and
# `restored_params` holds the params now in `train_state`.
raw_params = flax.core.unfreeze(original_params)
restored_params = flax.core.unfreeze(restored_params)
# Traversal is driven by the keys of the first argument (the updated state).
compare_params(restored_params, raw_params)


In [None]:
# @title Method 2

def flatten_restored_params(raw_params, restored_params):
  """Drops the leading axis of restored leaves whose shape differs, in place.

  Walks `raw_params`; for every leaf that also exists in `restored_params`
  and whose shape differs from the restored one, the restored entry is
  replaced by its first slice (`[0]`) and the key is printed. Keys absent
  from `restored_params` are skipped.

  Args:
    raw_params: Nested dict of reference leaves exposing `.shape`.
    restored_params: Nested dict that is modified in place.
  """
  for name, reference in raw_params.items():
    if name not in restored_params:
      continue
    candidate = restored_params[name]

    # Sub-tree: descend only when both sides are dicts.
    if isinstance(reference, dict):
      if isinstance(candidate, dict):
        flatten_restored_params(reference, candidate)
      continue

    # Leaf: nothing to do when the shapes already agree.
    if reference.shape == candidate.shape:
      continue
    restored_params[name] = candidate[0]
    print(name)


# Init model
# Fresh model and train state; the input spec shape is [8, 300, 30, 1].
model, train_state = init_model(config, input_shape=[8, 300, 30, 1])

# Init from XM (XID, WID)
# Returns the config of the restored run and the path of its checkpoint.
restored_model_cfg, init_checkpoint_path = xm_utils.get_info_from_xmanager(
    xid, wid
)
restored_train_state = pretrain_utils.restore_pretrained_checkpoint(
    init_checkpoint_path, train_state, assert_exist=True
)
# Independent copy of the restored state, used as the comparison reference.
cp_restored_train_state = copy.deepcopy(restored_train_state)

# Update train state
# Here `raw_params` are the freshly initialized params and `restored_params`
# the checkpoint's params; the latter are fixed up in place and then installed
# wholesale into the train state.
raw_params = flax.core.unfreeze(train_state.params)
restored_params = flax.core.unfreeze(restored_train_state.params)
flatten_restored_params(raw_params, restored_params)
train_state = train_state.replace(params=flax.core.freeze(restored_params))  # pytype: disable=attribute-error

# Check Same-ness
original_params = cp_restored_train_state.params
restored_params = train_state.params

# Both names are rebound here: `raw_params` now holds the untouched
# checkpoint params and `restored_params` the params now in `train_state`.
raw_params = flax.core.unfreeze(original_params)
restored_params = flax.core.unfreeze(restored_params)
# Traversal is driven by the keys of the first argument (the updated state).
compare_params(restored_params, raw_params)


In [None]:
# @title Method 3

def _restore_params(params, restored_params):
  """Copies matching entries of `restored_params` into `params`, in place.

  Walks the nested `params` dict. For every leaf whose key also exists in
  `restored_params`:
    * if the shapes are equal, the restored value replaces the current one;
    * if the restored value has one extra leading axis of size 1, its first
      slice (`[0]`) replaces the current one;
    * otherwise a `ValueError` is raised.
  Keys of `params` that are absent from `restored_params` keep their current
  value, and keys that exist only in `restored_params` are ignored.

  Args:
    params: Nested dict of leaves exposing `.shape`; modified in place.
    restored_params: Nested dict holding the values to restore from.

  Raises:
    ValueError: If a leaf present on both sides has an incompatible shape.
  """
  for key, value in params.items():
    if key not in restored_params:
      continue  # Nothing to restore for this key.
    restored_value = restored_params[key]

    # Sub-tree: recurse only when both sides are dicts.
    if isinstance(value, dict):
      if isinstance(restored_value, dict):
        _restore_params(value, restored_value)
      continue

    params_shape = value.shape
    restored_shape = restored_value.shape

    # Transferable shape (same shape).
    if params_shape == restored_shape:
      params[key] = restored_value

    # Transferable shape (restored params have a leading dim of size 1).
    elif params_shape == restored_shape[1:] and restored_shape[0] == 1:
      params[key] = restored_value[0]

    # Non-transferable shape (different shape).
    else:
      raise ValueError(
          f'Unable to restore {key} from restored params of shape '
          f'{restored_shape} to params of shape {params_shape}'
      )


def restore_from_train_state(train_state, restored_train_state):
  """Returns `train_state` with params taken from `restored_train_state`.

  Args:
    train_state: Train state whose `params` define the target structure.
    restored_train_state: Train state whose `params` supply the values.

  Returns:
    The result of `train_state.replace(params=...)` with the updated, frozen
    params.
  """
  # Unfrozen copies are needed so the tree can be traversed and modified.
  target_params = flax.core.unfreeze(train_state.params)
  source_params = flax.core.unfreeze(restored_train_state.params)

  # Updates `target_params` in place.
  _restore_params(target_params, source_params)

  return train_state.replace(params=flax.core.freeze(target_params))  # pytype: disable=attribute-error


# Init model
# Fresh model and train state; the input spec shape is [8, 300, 30, 1].
model, train_state = init_model(config, input_shape=[8, 300, 30, 1])

# Init from XM (XID, WID)
# Returns the config of the restored run and the path of its checkpoint.
restored_model_cfg, init_checkpoint_path = xm_utils.get_info_from_xmanager(
    xid, wid
)
restored_train_state = pretrain_utils.restore_pretrained_checkpoint(
    init_checkpoint_path, train_state, assert_exist=True
)
# Independent copy of the restored state, used as the comparison reference.
cp_restored_train_state = copy.deepcopy(restored_train_state)

# Update train state
# Method under test: copy restored values into the initialized params.
train_state = restore_from_train_state(train_state, restored_train_state)

# Check Same-ness
original_params = cp_restored_train_state.params
restored_params = train_state.params

# Note the naming: `raw_params` holds the checkpoint's params and
# `restored_params` holds the params now in `train_state`.
raw_params = flax.core.unfreeze(original_params)
restored_params = flax.core.unfreeze(restored_params)
# Traversal is driven by the keys of the first argument (the updated state).
compare_params(restored_params, raw_params)
