## LSM Checkpoint Loading and Sample Inference Visualization.

### Information:
- Adapted from the Scenic framework.
- Colab Kernel: `Electrodes Colab A/B/C`
- Dataset: `Electrodes`

- Grant request for Access on Demand (AoD):
https://grants.corp.google.com/#/grants?request=20h%2Fchr-ards-electrodes-deid-colab-jobs&reason=b%2F314799341

### About this notebook:
- Given a model config, restore the model from a checkpoint and visualize inference on one batch (8 samples) from the `validation` set.
- One figure is plotted per sample in the batch, showing 1) the original input, 2) the applied mask, and 3) the predicted reconstruction.
- Note that the `MAE` and `MSE` in the title of the reconstructed plot are the errors between the original input image and the predicted output image, averaged over all pixels (masked and unmasked). They are NOT the training loss, and are unrelated to the MAE (masked autoencoder) training method.
- This notebook provides examples of several masking (patching) strategies (`random`, `imputation`, `forecast`).

In [None]:
# @title Imports

import functools
from typing import Any, Callable, Dict, Iterator, Tuple, Optional, Type, Union

from absl import logging
from clu import metric_writers
from clu import periodic_actions
from clu import platform

import flax
from flax import jax_utils
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.profiler

import ml_collections
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds

import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
%matplotlib inline

from colabtools import adhoc_import
with adhoc_import.Google3():
  from scenic.dataset_lib import dataset_utils
  from scenic.google.xm import xm_utils
  from scenic.model_lib.base_models import base_model
  from scenic.model_lib.base_models import model_utils
  from scenic.model_lib.layers import nn_ops
  from scenic.model_lib.layers import nn_layers
  from scenic.projects.baselines import vit
  from scenic.train_lib import optax as scenic_optax
  from scenic.train_lib import pretrain_utils
  from scenic.train_lib import train_utils

  from scenic.projects.multimask.models import model_utils as mm_model_utils

  from google3.experimental.largesensormodels.scenic.datasets import dataset_constants
  from google3.experimental.largesensormodels.scenic.datasets import lsm_tiny_dataset
  from google3.experimental.largesensormodels.scenic.models import lsm_vit as lsm_vit_mae
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_constants
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_utils as lsm_model_utils
  from google3.experimental.largesensormodels.scenic.trainers import lsm_mae_trainer

  from google3.pyglib import gfile


# Type aliases shared by the helpers in this notebook.

# A batch maps string keys (e.g. 'input_signal') to arrays.
Batch = Dict[str, jnp.ndarray]

# (array, array, dict of arrays) -> {metric name: (float, int)}.
MetricFn = Callable[
    [jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray]],
    Dict[str, Tuple[float, int]],
]

# (array, batch, optional array, array) -> scalar loss.
LossFn = Callable[[jnp.ndarray, Batch, Optional[jnp.ndarray], jnp.ndarray], float]

# Named learning-rate schedules: step -> learning rate.
LrFns = Dict[str, Callable[[jnp.ndarray], jnp.ndarray]]

# A 2-D or 3-D patch size.
Patch = Union[Tuple[int, int], Tuple[int, int, int]]

## Helper Functions

In [None]:
# @title Get Model Class and Trainer

def get_model_cls(model_name: str):
  """Maps a model name to its model class.

  Args:
    model_name: Model identifier. Only 'lsm_vit_mae' is supported.

  Returns:
    The model class for `model_name`.

  Raises:
    ValueError: If `model_name` is not recognized.
  """
  if model_name != 'lsm_vit_mae':
    raise ValueError(f'Unrecognized model: {model_name}.')
  return lsm_vit_mae.ViTMAESingleChannelModel


def get_train_fn(trainer_name):
  """Maps a trainer name to its training function.

  Args:
    trainer_name: Trainer identifier. Only 'lsm_mae_trainer' is supported.

  Returns:
    The `train` function of the matching trainer module.

  Raises:
    ValueError: If `trainer_name` is not recognized.
  """
  if trainer_name != 'lsm_mae_trainer':
    raise ValueError(f'Unrecognized trainer: {trainer_name}.')
  return lsm_mae_trainer.train

In [None]:
# @title Visualization Functions

def patches_to_image(
    patches: jnp.ndarray,
    grid_hw: Tuple[int, int],
    patch_hw: Tuple[int, int],
    num_channels: int,
) -> jnp.ndarray:
  """Reassembles a row-major sequence of flattened patches into an image.

  Args:
    patches: Array of shape [nh * nw, ph * pw * num_channels]. Patch k covers
      grid cell (k // nw, k % nw) and is flattened in (ph, pw, C) order.
    grid_hw: Number of patches (nh, nw) along height and width.
    patch_hw: Patch shape (ph, pw).
    num_channels: Number of image channels C.

  Returns:
    Image of shape [nh * ph, nw * pw, num_channels].
  """
  nh, nw = grid_hw
  ph, pw = patch_hw
  grid = jnp.reshape(patches, (nh, nw, ph, pw, num_channels))
  # Interleave grid and in-patch axes -> (nh, ph, nw, pw, C).
  rows = jnp.transpose(grid, (0, 2, 1, 3, 4))
  return jnp.reshape(rows, (nh * ph, nw * pw, num_channels))


def plot_reconstructed_image(
    original_img: jnp.ndarray,
    reconstructed_img: jnp.ndarray,
    masked_tokens: jnp.ndarray,
    step: Optional[int] = None,
    split_name: Optional[str] = None,
    traspose_img: bool = True,
    patch_size: Optional[Tuple[int, int]] = None,
):
  """Plots the original image, the token mask, and the reconstructed image.

  Args:
    original_img: The original image of shape [H, W, C].
    reconstructed_img: The reconstructed image in patch form, of shape
      [num patches, patch size (ph*pw*C)].
    masked_tokens: Per-patch mask values of shape [num patches]. The mask
      panel shows 1 - masked_tokens, so nonzero entries (assumed to mark
      masked patches) are drawn dark.
    step: Optional training step; appended to the original-image title.
    split_name: Optional data split name; appended to the original-image title.
    traspose_img: If True, swap H and W of every panel before plotting. The
      misspelled name is kept for backward compatibility.
    patch_size: Optional (ph, pw) patch shape. Defaults to the notebook-global
      `config.model.patches.size`.

  Returns:
    A dict with the three plotted arrays ('img', 'mask', 'reconstructed') and
    their titles ('img_title', 'mask_title', 'reconstructed_title').

  Raises:
    ValueError: If the image cannot be tiled evenly by the patch size.
  """
  img_h, img_w, img_c = original_img.shape
  if patch_size is None:
    # Fall back to the notebook-global `config` set by the inference cells.
    patch_size = config.model.patches.size
  ph, pw = patch_size
  if img_h % ph or img_w % pw:
    raise ValueError(
        f'Image of shape {original_img.shape} cannot be tiled by patches of '
        f'size {(ph, pw)}.'
    )
  nh, nw = img_h // ph, img_w // pw

  reconstructed = patches_to_image(
      reconstructed_img, (nh, nw), (ph, pw), img_c
  )

  # Spread each patch's mask value over all of its pixels, then invert so
  # masked patches render as 0 and visible ones as 1.
  token_mask = jnp.reshape(
      jnp.asarray(masked_tokens, dtype=jnp.float32), (nh * nw, 1)
  )
  patch_mask = jnp.broadcast_to(token_mask, (nh * nw, ph * pw * img_c))
  mask_img = 1 - patches_to_image(patch_mask, (nh, nw), (ph, pw), img_c)

  # Errors over all pixels (masked and visible), not only masked patches.
  mse = np.mean(np.square(original_img - reconstructed))
  mae = np.mean(np.abs(original_img - reconstructed))

  original_title = 'Original'
  if split_name is not None:
    original_title += f' ({split_name})'
  if step is not None:
    original_title += f' Step: {step}'
  masked_title = 'Mask'
  reconstructed_title = f'Reconstructed MAE: {mae:.2f}, MSE: {mse:.2f}'

  original_plot, mask_plot, reconstructed_plot = (
      original_img, mask_img, reconstructed
  )
  if traspose_img:
    original_plot, mask_plot, reconstructed_plot = (
        jnp.transpose(panel, (1, 0, 2))
        for panel in (original_plot, mask_plot, reconstructed_plot)
    )

  # Share the original's colour range so the reconstruction is comparable.
  vmin = jnp.min(original_plot)
  vmax = jnp.max(original_plot)

  fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(20, 8))
  ax[0].imshow(original_plot, vmin=vmin, vmax=vmax)
  ax[0].set_title(original_title)
  # Fixed [0, 1] range so fully-masked / fully-visible masks render correctly.
  ax[1].imshow(mask_plot, vmin=0, vmax=1)
  ax[1].set_title(masked_title)
  ax[2].imshow(reconstructed_plot, vmin=vmin, vmax=vmax)
  ax[2].set_title(reconstructed_title)
  # `show` takes no positional arguments on standard backends.
  plt.show()
  # Release the figure; one is created per sample when plotting a batch.
  plt.close(fig)

  return {
      'img': original_plot,
      'img_title': original_title,
      'mask': mask_plot,
      'mask_title': masked_title,
      'reconstructed': reconstructed_plot,
      'reconstructed_title': reconstructed_title,
  }




In [None]:
# @title Checkpoint Restoration and Sample Inference Functions

def _plot_eval_batches(
    eval_step_fn: Callable[..., Any],
    train_state: Any,
    valid_iters: Dict[str, Iterator[Any]],
    eval_batch_size: int,
    step: int,
    num_plot_samples: Optional[int] = None,
) -> list:
  """Runs one eval batch per split and plots a reconstruction per sample.

  Args:
    eval_step_fn: pmapped eval step; called as fn(train_state, batch) and
      returning (metrics, logits, aux) with aux['token_mask'].
    train_state: Replicated train state.
    valid_iters: Mapping from split name to a batch iterator.
    eval_batch_size: Configured eval batch size (upper bound on plots).
    step: Checkpoint step shown in the plot titles.
    num_plot_samples: Optional additional cap on plots per split.

  Returns:
    List of dicts returned by `plot_reconstructed_image`, one per sample.
  """
  plot_dict_list = []
  for split_name, split_iter in valid_iters.items():
    eval_batch = next(split_iter)
    # The eval step is called without an rng, as in the trainer's eval loop.
    _, logits, aux = eval_step_fn(train_state, eval_batch)
    signals = eval_batch['input_signal']
    token_mask = aux['token_mask']
    # Arrays are laid out as [device, per-device batch, ...]; only the first
    # device's shard is plotted. Bound by the real per-device size: larger
    # indices raise for NumPy arrays and are silently clamped by JAX (which
    # would yield duplicate plots of the last sample).
    num_samples = min(eval_batch_size, signals.shape[1])
    if num_plot_samples is not None:
      num_samples = min(num_samples, num_plot_samples)
    for j in range(num_samples):
      plot_dict_list.append(
          plot_reconstructed_image(
              original_img=signals[0, j],
              reconstructed_img=logits[0, j],
              masked_tokens=token_mask[0, j],
              step=step,
              split_name=split_name,
          )
      )
  return plot_dict_list


def run_eval_from_ckpt(
    rng,
    config,
    model_cls,
    dataset,
    workdir,
    writer=None,
    step=None,
    num_plot_samples: Optional[int] = None,
):
  """Restores a checkpoint and plots reconstructions for one eval batch.

  Args:
    rng: JAX PRNG key used for model initialization.
    config: Experiment ConfigDict (see `get_config`).
    model_cls: Model class, e.g. `get_model_cls(config.model_name)`.
    dataset: Dataset object exposing `train_iter`, `valid_iter` (an iterator
      or a dict of iterators) and `meta_data`.
    workdir: Directory containing the checkpoint to restore.
    writer: Unused; kept for call-site compatibility.
    step: Checkpoint step to restore; None lets
      `train_utils.restore_checkpoint` choose (presumably the latest).
    num_plot_samples: Optional cap on the number of samples plotted per split.

  Returns:
    A list of dicts, one per plotted sample, as returned by
    `plot_reconstructed_image`.
  """
  del writer  # Unused.

  model = model_cls(config, dataset.meta_data)

  # Initialize parameters from the shape/dtype of one training batch.
  rng, params_init_rng, dropout_init_rng = jax.random.split(rng, num=3)
  init_rngs = {'params': params_init_rng, 'dropout': dropout_init_rng}
  init_batch = next(dataset.train_iter)
  params, model_state, _, _ = train_utils.initialize_model(
      model_def=model.flax_model,
      input_spec=[(
          init_batch['input_signal'].shape[1:],
          init_batch['input_signal'].dtype,
      )],
      config=config,
      rngs=init_rngs,
      train=True,  # So that masking and decoding in MAE are initialized.
  )

  param_count = sum(x.size for x in jax.tree.leaves(params))
  print(f'\nModel Parameter Count {param_count}\n')

  # Build the optimizer state too: the TrainState below is the restore
  # target, so it presumably has to mirror what the trainer saved.
  schedule_fns = scenic_optax.make_schedule(config.get('schedule'))
  tx, _ = scenic_optax.make(config.optimizer, schedule_fns, params)
  opt_state = tx.init(params)

  rng, train_rng = jax.random.split(rng)
  chrono = train_utils.Chrono()
  train_state = train_utils.TrainState(
      global_step=0,
      opt_state=opt_state,
      tx=tx,
      params=params,
      model_state=model_state,
      rng=train_rng,
      metadata={'chrono': chrono.save()},
  )
  start_step = train_state.global_step

  # Set before any checkpoint I/O so every call (including the first in a
  # session) restores under the same Flax checkpointing setting.
  flax.config.update('flax_use_orbax_checkpointing', False)

  if config.checkpoint:
    train_state, start_step = train_utils.restore_checkpoint(
        workdir, train_state, step=step
    )
    if start_step != 0:
      print(f'Restoring checkpoint from train step: {start_step}\n\n')
  chrono.load(train_state.metadata['chrono'])
  train_state = train_state.replace(metadata={})

  # Only fall back to `init_from` when no checkpoint was restored above.
  if start_step == 0 and config.get('init_from') is not None:
    restored_model_cfg = config.init_from.get('model_config')
    init_checkpoint_path = config.init_from.get('checkpoint_path')
    # BEGIN GOOGLE-INTERNAL
    if config.init_from.get('xm'):
      xid, wid = config.init_from.get('xm')
      (restored_model_cfg, init_checkpoint_path) = (
          xm_utils.get_info_from_xmanager(xid, wid)
      )
    # END GOOGLE-INTERNAL
    checkpoint_format = config.init_from.get('checkpoint_format', 'scenic')
    if init_checkpoint_path is not None:
      if checkpoint_format == 'scenic':
        restored_train_state = pretrain_utils.restore_pretrained_checkpoint(
            init_checkpoint_path, train_state, assert_exist=True
        )
        # Load params from the init_model.
        train_state = model.init_from_train_state(  # pytype: disable=attribute-error
            train_state, restored_train_state, restored_model_cfg
        )
        del restored_train_state
      else:
        raise ValueError(f'Unsupported checkpoint format: {checkpoint_format}')

  # Replicate the optimizer, state, and rng across devices.
  train_state = jax_utils.replicate(train_state)
  del params  # Do not keep a copy of the initial params.

  total_steps, steps_per_epoch = train_utils.get_num_training_steps(
      config, dataset.meta_data
  )

  # No buffer donation: the eval batch is read again after the step to build
  # the plots, and donated device buffers may be invalidated by the call.
  eval_step_pmapped = jax.pmap(
      functools.partial(
          lsm_mae_trainer.eval_step,
          flax_model=model.flax_model,
          metrics_fn=model.get_metrics_fn('validation'),
          config=config,
          debug=config.debug_eval,
      ),
      axis_name='batch',
  )

  chrono.inform(start_step, total_steps, config.batch_size, steps_per_epoch)

  ################### EVALUATION #######################
  train_state = train_utils.sync_model_state_across_replicas(train_state)
  valid_iters = dataset.valid_iter
  if not isinstance(valid_iters, dict):  # Single validation set.
    valid_iters = {'valid': valid_iters}

  return _plot_eval_batches(
      eval_step_pmapped,
      train_state,
      valid_iters,
      eval_batch_size=config.get('eval_batch_size', config.batch_size),
      step=start_step,
      num_plot_samples=num_plot_samples,
  )

## Sample Inference and Patching Examples

In [None]:
# @title Base Config

r"""A config to train a Base ViT MAE on LSM dataset.

Forked from google3/third_party/py/scenic/projects/multimask/configs/mae_cifar10_tiny.py

To run on XManager:
gxm third_party/py/scenic/google/xm/launch_xm.py -- \
--binary //experimental/largesensormodels/scenic:main \
--config=experimental/largesensormodels/scenic/configs/mae_lsm_base.py \
--platform=vlp_4x8 \
--exp_name=lsm_mae_tier2_base_10_5_res \
--workdir=/cns/dz-d/home/xliucs/lsm/xm/\{xid\} \
--xm_resource_alloc=group:mobile-dynamic/h2o-ai-gqm-quota \
--scheduling_time_quantum=2d \
--priority=200

To run locally:
./third_party/py/scenic/google/runlocal.sh \
--uptc="" \
--binary=//experimental/largesensormodels/scenic:main \
--config=$(pwd)/experimental/largesensormodels/scenic/configs/mae_lsm_base.py:runlocal
"""

from typing import Optional
import ml_collections
from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_constants
# from google3.experimental.largesensormodels.scenic.utils import linear_probe_config


# Module-level settings consumed by get_config(). get_config() reads these
# names at call time, so later cells can override them (e.g. TOKEN_MASK_PROB)
# before calling it.

# 1) Dataset.
DATASET_NAME = 'lsm_300min_pretraining_165K_n10'
CACHE_DATASET = True
TRAIN_DATA_SIZES = [1000, 10000, 100000, 750000, 1321235]  # Swept in get_hyper.
USE_DATETIME_FEATURES = False
USE_TRAIN_AUGMENTATIONS = [True]  # Swept in get_hyper.
TRAIN_AUGMENTATIONS = ['stretch', 'flip', 'noise']
SHUFFLE_SEED = 42

# 2) Training / evaluation.
BATCH_SIZE = 4096
NUM_TRAIN_STEPS = 50000
LRS = [5e-3]  # Swept in get_hyper.
WEIGHT_DECAYS = [1e-4]  # Swept in get_hyper.

# 3) Logging and checkpointing (intervals in training steps).
LOG_EVAL_SUMMARY_STEPS = 500
LOG_CHECKPOINT_STEPS = 100
# Enough slots to keep every checkpoint written over the full run.
MAX_NUM_CHECKPOINTS = NUM_TRAIN_STEPS // LOG_CHECKPOINT_STEPS

# Model variant / patch H (time steps) / patch W (features).
VARIANT = 'B/10/5'
TOKEN_MASK_PROB = 'constant_0.8'
LOSS_ONLY_MASKED_TOKENS = True

# Downstream tasks.
# Horizons shared by the imputation and forecast reconstruction evals.
RECONSTRUCTION_HORIZONS = [0.1, 0.2, 0.4]

# Linear-probe eval (only referenced by commented-out code in this notebook).
LINEAR_PROBE_USE_TRAIN_AUGMENTATIONS = False
LINEAR_PROBE_TRAIN_AUGMENTATIONS = ['noise']


def get_config_common_few_shot(
    batch_size: Optional[int] = None,
    target_resolution: int = 224,
    resize_resolution: int = 256,
) -> ml_collections.ConfigDict:
  """Builds a standard few-shot evaluation configuration.

  Copied from
  third_party/py/scenic/projects/baselines/configs/google/common/common_fewshot.py

  Args:
    batch_size: Batch size for the few-shot evaluation.
    target_resolution: Side length of the central crop.
    resize_resolution: Side length images are resized to before cropping.

  Returns:
    A ConfigDict holding the few-shot evaluation settings.
  """
  # Train and eval share one preprocessing pipeline.
  preprocess = (
      f'decode|resize({resize_resolution})'
      f'|central_crop({target_resolution})|value_range(-1,1)'
  )

  config = ml_collections.ConfigDict()
  config.batch_size = batch_size
  config.representation_layer = 'pre_logits'
  config.log_eval_steps = 25_000
  # name -> (TFDS dataset, train split, eval split)
  config.datasets = {
      'birds': ('caltech_birds2011', 'train', 'test'),
      'caltech': ('caltech101', 'train', 'test'),
      'cars': ('cars196:2.1.0', 'train', 'test'),
      'cifar100': ('cifar100', 'train', 'test'),
      'col_hist': ('colorectal_histology', 'train[:2000]', 'train[2000:]'),
      'dtd': ('dtd', 'train', 'test'),
      'imagenet': ('imagenet2012_subset/10pct', 'train', 'validation'),
      'pets': ('oxford_iiit_pet', 'train', 'test'),
      'uc_merced': ('uc_merced', 'train[:1000]', 'train[1000:]'),
  }
  config.pp_train = preprocess
  config.pp_eval = preprocess
  config.shots = [1, 5, 10, 25]
  config.l2_regs = [2.0**exponent for exponent in range(-10, 20)]
  config.walk_first = ('imagenet', 10)
  return config


def get_config(runlocal=''):
  """Builds the LSM ViT-MAE experiment configuration.

  Reads the module-level constants (DATASET_NAME, VARIANT, TOKEN_MASK_PROB,
  ...) at call time.

  Args:
    runlocal: Any truthy value selects a small local setup: the 'Deb' model
      variant, batch size 8, a 64-example eval split, and no FLOP counting.

  Returns:
    An ml_collections.ConfigDict covering dataset, model, masked-feature loss,
    schedule, optimizer, reconstruction evals, and logging.

  Raises:
    ValueError: If VARIANT is not of the form 'V/P' or 'V/PH/PW'.
  """
  runlocal = bool(runlocal)

  config = ml_collections.ConfigDict()

  # Experiment.
  config.experiment_name = f'electrodes-mae-{DATASET_NAME}'
  config.dataset_name = f'lsm_prod/{DATASET_NAME}'
  config.shuffle_seed = SHUFFLE_SEED

  # Dataset.
  config.data_dtype_str = 'float32'
  data = ml_collections.ConfigDict()
  data.dataset = f'lsm_prod/{DATASET_NAME}'
  data.num_classes = None
  data.train_split = 'train'
  data.train_num_samples = TRAIN_DATA_SIZES[0]
  # The eval split is used for both validation and test.
  data.eval_split = 'test[:64]' if runlocal else 'test'
  data.cache_dataset = CACHE_DATASET
  data.prefetch_to_device = 2
  # The shuffle buffer is per host, so a moderate size is fine.
  data.shuffle_buffer_size = 250_000
  config.dataset_configs = data

  # Model variant: 'V/PH/PW' (patch height = time steps, width = features)
  # or 'V/P' for square patches.
  variant_parts = VARIANT.split('/')
  if len(variant_parts) == 3:
    version, patch_h, patch_w = variant_parts
  elif len(variant_parts) == 2:
    version, patch_h = variant_parts
    patch_w = patch_h
  else:
    raise ValueError(f'Invalid model variant: {VARIANT}')
  if runlocal:
    version = 'Deb'

  config.model_name = 'lsm_vit_mae'
  model = ml_collections.ConfigDict()
  # Encoder.
  model.hidden_size = model_constants.HIDDEN_SIZES[version]
  model.patches = ml_collections.ConfigDict()
  model.patches.size = (int(patch_h), int(patch_w))
  model.num_heads = model_constants.NUM_HEADS[version]
  model.mlp_dim = model_constants.MLP_DIMS[version]
  model.num_layers = model_constants.NUM_LAYERS[version]
  model.dropout_rate = 0.0
  model.classifier = 'none'  # Must be 'none' for the autoencoder.
  model.representation_size = None
  model.positional_embedding = 'sinusoidal_2d'
  model.positional_embedding_decoder = 'sinusoidal_2d'
  # Decoder.
  decoder = ml_collections.ConfigDict()
  decoder.hidden_size = model_constants.DECODER_HIDDEN_SIZES[version]
  decoder.mlp_dim = model_constants.DECODER_MLP_DIMS[version]
  decoder.num_layers = model_constants.DECODER_NUM_LAYERS[version]
  decoder.num_heads = model_constants.DECODER_NUM_HEADS[version]
  decoder.dropout_rate = 0.0
  decoder.attention_dropout_rate = 0.0
  model.decoder_config = decoder
  config.model = model

  # Masked reconstruction loss.
  loss = ml_collections.ConfigDict()
  loss.targets_type = 'rgb'
  loss.token_mask_probability = TOKEN_MASK_PROB
  loss.loss_only_masked_tokens = LOSS_ONLY_MASKED_TOKENS
  loss.loss_type = 'squared'  # 'squared' or 'absolute'.
  config.masked_feature_loss = loss

  config.use_datetime_features = USE_DATETIME_FEATURES

  # Training.
  config.trainer_name = 'lsm_mae_trainer'
  config.batch_size = 8 if runlocal else BATCH_SIZE
  config.num_training_steps = NUM_TRAIN_STEPS
  config.log_eval_steps = LOG_EVAL_SUMMARY_STEPS
  config.log_summary_steps = LOG_EVAL_SUMMARY_STEPS
  config.rng_seed = 42
  config.use_train_augmentations = USE_TRAIN_AUGMENTATIONS[0]
  config.train_augmentations = TRAIN_AUGMENTATIONS

  # One learning-rate schedule applied to all parameters.
  sched = ml_collections.ConfigDict()
  sched.re = '(.*)'
  sched.lr_configs = ml_collections.ConfigDict()
  sched.lr_configs.learning_rate_schedule = 'compound'
  sched.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
  sched.lr_configs.total_steps = NUM_TRAIN_STEPS
  sched.lr_configs.steps_per_cycle = NUM_TRAIN_STEPS
  sched.lr_configs.warmup_steps = int(NUM_TRAIN_STEPS * 0.05)
  sched.lr_configs.base_learning_rate = LRS[0]
  config.schedule = ml_collections.ConfigDict({'all': sched})

  # A single optimizer.
  optim = ml_collections.ConfigDict()
  optim.optax_name = 'scale_by_adam'
  optim.optax_configs = ml_collections.ConfigDict({'b1': 0.9, 'b2': 0.95})
  # NOTE(review): this is a top-level key; a previously commented-out line
  # suggests it was meant to be `optim.optax`. Left unchanged so the restored
  # optimizer state keeps matching the checkpoints trained with this config.
  config.optax = dict(mu_dtype='bfloat16')
  optim.max_grad_norm = 1.0
  optim.weight_decay = WEIGHT_DECAYS[0]
  optim.weight_decay_decouple = True
  config.optimizer = optim

  # Downstream tasks. Linear probing (linear_probe_config) and few-shot
  # (get_config_common_few_shot) evals are disabled in this config.
  # TODO(girishvn): Few-shot needs to be adapted to the electrode dataset.

  # Reconstruction eval tasks (forecast and imputation).
  for task in ('forecast', 'imputation'):
    config[task] = ml_collections.ConfigDict()
    config[task].horizons = RECONSTRUCTION_HORIZONS

  # Logging and checkpointing.
  config.write_summary = True
  config.xprof = True  # Profile using xprof.
  config.checkpoint = True
  config.checkpoint_steps = LOG_CHECKPOINT_STEPS
  config.debug_train = False
  config.debug_eval = False
  config.max_checkpoints_to_keep = MAX_NUM_CHECKPOINTS
  # BEGIN GOOGLE-INTERNAL
  if runlocal:
    # Current implementation fails with UPTC.
    config.count_flops = False
  # END GOOGLE-INTERNAL

  return config


# BEGIN GOOGLE-INTERNAL
def get_hyper(hyper):
  """Returns the grid-search sweep: the product of all swept settings."""
  sweeps = [
      hyper.sweep('config.schedule.all.lr_configs.base_learning_rate', LRS),
      hyper.sweep('config.optimizer.weight_decay', WEIGHT_DECAYS),
      hyper.sweep('config.dataset_configs.train_num_samples', TRAIN_DATA_SIZES),
      hyper.sweep('config.use_train_augmentations', USE_TRAIN_AUGMENTATIONS),
  ]
  return hyper.product(sweeps)
# END GOOGLE-INTERNAL


In [None]:
# @title Available Compute

# Report the accelerators visible to JAX and the active backend.
print(f'Available devices: {jax.devices()}')
print(f'Default device: {jax.default_backend()}')

In [None]:
# @title 80% Random Patching

# Per-run settings.
TOKEN_MASK_PROB = 'constant_0.8_random'  # Masking strategy; read by get_config().
workdir = '/cns/dz-d/home/xliucs/lsm/xm/124248847/5/'  # Checkpoint directory.
step = None  # Checkpoint step to restore (None: let restore_checkpoint choose).

# The last two non-empty path components are the experiment ID and
# (presumably) the work-unit ID.
work_dir_split = [part for part in workdir.split('/') if part]
xid, wid = work_dir_split[-2], work_dir_split[-1]

print(f'Workdir {workdir}')
print(f'XID {xid}')
print(f'WID {wid}')

# Build config, model class and data, then restore and plot. `config` stays a
# notebook global because plot_reconstructed_image reads it.
config = get_config(runlocal=False)
rng = jax.random.PRNGKey(config.rng_seed)
model_cls = get_model_cls(config.model_name)
data_rng, rng = jax.random.split(rng)
dataset = lsm_tiny_dataset.get_dataset(config, data_rng)

output_plots = run_eval_from_ckpt(
    rng, config, model_cls, dataset, workdir, step=step
)

In [None]:
# @title 50% Temporal Imputation Patching

# Per-run settings.
TOKEN_MASK_PROB = 'constant_0.5_imputation_time'  # Masking strategy; read by get_config().
workdir = '/cns/dz-d/home/xliucs/lsm/xm/117310802/1/'  # Checkpoint directory.
step = None  # Checkpoint step to restore (None: let restore_checkpoint choose).

# The last two non-empty path components are the experiment ID and
# (presumably) the work-unit ID.
work_dir_split = [part for part in workdir.split('/') if part]
xid, wid = work_dir_split[-2], work_dir_split[-1]

print(f'Workdir {workdir}')
print(f'XID {xid}')
print(f'WID {wid}')

# Build config, model class and data, then restore and plot. `config` stays a
# notebook global because plot_reconstructed_image reads it.
config = get_config(runlocal=False)
rng = jax.random.PRNGKey(config.rng_seed)
model_cls = get_model_cls(config.model_name)
data_rng, rng = jax.random.split(rng)
dataset = lsm_tiny_dataset.get_dataset(config, data_rng)

output_plots = run_eval_from_ckpt(
    rng, config, model_cls, dataset, workdir, step=step
)

In [None]:
# @title 25% Temporal Forecast Patching

# Per-run settings.
TOKEN_MASK_PROB = 'constant_0.25_forecast_time'  # Masking strategy; read by get_config().
workdir = '/cns/dz-d/home/xliucs/lsm/xm/117310802/1/'  # Checkpoint directory.
step = None  # Checkpoint step to restore (None: let restore_checkpoint choose).

# The last two non-empty path components are the experiment ID and
# (presumably) the work-unit ID.
work_dir_split = [part for part in workdir.split('/') if part]
xid, wid = work_dir_split[-2], work_dir_split[-1]

print(f'Workdir {workdir}')
print(f'XID {xid}')
print(f'WID {wid}')

# Build config, model class and data, then restore and plot. `config` stays a
# notebook global because plot_reconstructed_image reads it.
config = get_config(runlocal=False)
rng = jax.random.PRNGKey(config.rng_seed)
model_cls = get_model_cls(config.model_name)
data_rng, rng = jax.random.split(rng)
dataset = lsm_tiny_dataset.get_dataset(config, data_rng)

output_plots = run_eval_from_ckpt(
    rng, config, model_cls, dataset, workdir, step=step
)

In [None]:
# @title 35% Feature Imputation Patching

# Per-run settings.
TOKEN_MASK_PROB = 'constant_0.35_imputation_feature'  # Masking strategy; read by get_config().
workdir = '/cns/dz-d/home/xliucs/lsm/xm/117310802/1/'  # Checkpoint directory.
step = None  # Checkpoint step to restore (None: let restore_checkpoint choose).

# The last two non-empty path components are the experiment ID and
# (presumably) the work-unit ID.
work_dir_split = [part for part in workdir.split('/') if part]
xid, wid = work_dir_split[-2], work_dir_split[-1]

print(f'Workdir {workdir}')
print(f'XID {xid}')
print(f'WID {wid}')

# Build config, model class and data, then restore and plot. `config` stays a
# notebook global because plot_reconstructed_image reads it.
config = get_config(runlocal=False)
rng = jax.random.PRNGKey(config.rng_seed)
model_cls = get_model_cls(config.model_name)
data_rng, rng = jax.random.split(rng)
dataset = lsm_tiny_dataset.get_dataset(config, data_rng)

output_plots = run_eval_from_ckpt(
    rng, config, model_cls, dataset, workdir, step=step
)