# MAX Colab Example
Licensed under the Apache License, Version 2.0

## Imports

In [1]:
import functools
import dataclasses
import re

import tensorflow as tf
import gin
from six.moves import reload_module
import pandas as pd
import tqdm
import numpy as np
import tensorflow_datasets as tfds
import jax
from jax import numpy as jnp
import flax
import flax.linen as nn
import optax
import seqio
from pprint import pprint


# Create logging handlers to catch logging or absl.logging outputs in the notebook
import logging
from absl import logging as absl_logging

# Set up the logging configuration: the root Python logger is configured at
# INFO level and absl's verbosity is set to 'info'.
logging.basicConfig(level=logging.INFO)
absl_logging.set_verbosity('info')

# Install absl's logging handler.
# NOTE(review): per the absl documentation, use_absl_handler() attaches absl's
# handler to the root Python logger (it does not redirect absl records into the
# stdlib handlers) -- confirm this is the intended direction.
absl_logging.use_absl_handler()

# Define an emit function to simply catch the log entry and print it
class JupyterHandler(logging.Handler):
  """Logging handler that writes each formatted record with `print`."""

  def emit(self, record):
    """Formats `record` with this handler's formatter and prints the text."""
    print(self.format(record))

# Add the custom handler to the root logger so that every record of level INFO
# or above that reaches the root logger is also printed in the cell output.
# NOTE(review): other handlers may already be attached to the root logger at
# this point, in which case messages can show up more than once -- confirm.
jupyter_handler = JupyterHandler()
jupyter_handler.setLevel(logging.INFO)
logging.getLogger().addHandler(jupyter_handler)




All available devices:
TPU_0(process=0,(0,0,0,0))
TPU_1(process=0,(0,0,0,1))
TPU_2(process=0,(1,0,0,0))
TPU_3(process=0,(1,0,0,1))
TPU_4(process=0,(0,1,0,0))
TPU_5(process=0,(0,1,0,1))
TPU_6(process=0,(1,1,0,0))
TPU_7(process=0,(1,1,0,1))


Successfully created mesh:
Mesh('data': 4, 'model': 2)


In [None]:
import mediapy as media
import t5
import clu
from clu.data import dataset_iterator
from clu import parameter_overview

from max import modeling as mnn
from max.config import base as base_config
from max.config import registry
from max.config import validators
from max.core import constants
from max.data import config as data_config
from max.data import loading
from max.data import processing
from max.data import tokenizers
from max.data.datasets import config as datasets_config
from max.data.datasets import dataloader
from max.execution import checkpointing
from max.execution import config as exec_config
from max.execution import executors
from max.modeling import config as modeling_config
from max.modeling import linear
from max.modeling import multimodal
from max.modeling import stochastic
from max.modeling import transformers
from max.modeling.garden import config as garden_config
from max.modeling.garden import imp
from max.modeling.garden import vit
from max.optimization import config as opt_config
from max.optimization import objectives
from max.projects.imp.config import data as imp_data_config
from max.projects.imp.config import experiment as imp_exp_config
from max.projects.imp.config import model as imp_model_config
from max.utils import sharding
from max.utils import typing

In [None]:
# Short aliases for the MAX config registrar and the seqio task registry.
Registrar = registry.Registrar
TaskRegistry = seqio.TaskRegistry

# Default T5 vocabulary; used by decode_text_labels to turn token ids into text.
VOCABULARY = t5.data.get_default_vocabulary()

# Print a header followed by one device per line.
print(*('All available devices:', *jax.devices(), '\n'), sep='\n')

# Device mesh with two named axes, 'data' and 'model'.
# NOTE(review): the (4, 2) mesh shape presumably requires exactly 8 devices --
# adjust it before running on different hardware.
mesh = jax.sharding.Mesh(
    sharding.create_tpu_device_mesh(ici_mesh_shape=(4, 2),
                                    dcn_mesh_shape=(1, 1)),
    ['data', 'model'],
)
print(f'Successfully created mesh:\n{mesh}')

## Dataset Preview

In [None]:
def load_dataset(dataset_config, batch_size=8, num_frames=32, resolution=128):
  """Creates a non-shuffled, non-training loader for a single dataset.

  Args:
    dataset_config: dataset configuration handed to `data_config.Loader`.
    batch_size: batch size handed to `data_config.Loader`.
    num_frames: used as both the vision temporal size and the vision temporal
      patch size.
    resolution: side length; the vision spatial size is
      `(resolution, resolution, 3)` and the spatial patch size is
      `(resolution, resolution)`.

  Returns:
    The 'loader' entry of the first element returned by
    `dataloader.create_data`.
  """
  # NOTE: `batch_size` is deliberately not reassigned here; an earlier version
  # overwrote it with 8, which silently ignored the caller's argument.
  experiment = data_config.ExperimentData(
      vision_spatial_size=(resolution, resolution, 3),
      vision_spatial_patch_size=(resolution, resolution),
      vision_temporal_size=num_frames,
      vision_temporal_patch_size=num_frames,
      waveform_temporal_size=7680,
      waveform_temporal_patch_size=7680,
      text_size=None,
      num_epochs=-1,
      loaders=[
          data_config.Loader(
              dataset=dataset_config, batch_size=batch_size, shuffle=False),
      ],
      is_training=False)
  loaders = dataloader.create_data(experiment)

  return loaders[0]['loader']

def decode_videos(outputs, num_frames=32, resolution=128):
  """Reshapes the raw vision tokens of a batch into video arrays.

  Args:
    outputs: nested batch dict containing
      `outputs['inputs']['encoder']['vision']['token_raw']`.
    num_frames: number of frames per video; should match the `num_frames` the
      dataset was loaded with.
    resolution: frame side length; should match the `resolution` the dataset
      was loaded with.

  Returns:
    A numpy array of shape `(-1, num_frames, resolution, resolution, 3)` whose
    values are the raw values mapped through `(x + 1) / 2`.
  """
  # Index 0 is taken on the second and third axes of `token_raw`.
  raw = outputs['inputs']['encoder']['vision']['token_raw'][:, 0, 0]
  videos = tf.reshape(raw, (-1, num_frames, resolution, resolution, 3))
  # NOTE(review): presumably rescales from [-1, 1] to [0, 1] for display --
  # confirm against the data pipeline's normalization.
  return (videos.numpy() + 1.) / 2.

def decode_text_labels(outputs):
  """Decodes the encoder text token ids of a batch into Python strings.

  Args:
    outputs: nested batch dict containing
      `outputs['inputs']['encoder']['text']['token_id']`.

  Returns:
    A list with one string per batch row: the first decoded entry of that row,
    decoded from bytes.
  """
  token_ids = outputs['inputs']['encoder']['text']['token_id']
  decoded_rows = VOCABULARY.decode_tf(token_ids).numpy()
  texts = []
  for row in decoded_rows:
    # Only the first entry of each row is kept.
    texts.append(row[0].decode('utf8'))
  return texts

In [None]:
# Available datasets: print the name of every entry in
# datasets_config.ALL_DATASETS, one per line.
for ds in datasets_config.ALL_DATASETS:
  print(ds.name)

In [None]:
# Load ImageNet with the default preview settings and fetch a single batch.
ds = load_dataset(datasets_config.IMAGENET)
ds_iter = iter(ds)
outputs = next(ds_iter)

# Keep only the first frame of each decoded clip and display them as images.
images = decode_videos(outputs)[:, 0]
media.show_images(images)

# Print the decoded text label of each example, one per line.
labels = decode_text_labels(outputs)
print('\n'.join(labels))

In [None]:
# Load Kinetics-400 with the default preview settings and fetch a single batch.
ds = load_dataset(datasets_config.KINETICS400)
ds_iter = iter(ds)
outputs = next(ds_iter)

# Decode the full clips and display them as GIFs at 25 fps.
videos = decode_videos(outputs)
media.show_videos(videos, fps=25, codec='gif')

# Print the decoded text label of each example, one per line.
labels = decode_text_labels(outputs)
print('\n'.join(labels))

In [None]:
# Shape of every leaf in the last fetched batch, keyed by its '/'-joined path.
{k: v.shape for k, v in flax.traverse_util.flatten_dict(outputs, sep='/').items()}

{'inputs/encoder/spectrogram/token_coordinate': (8, 1, 32, 2),
 'inputs/encoder/spectrogram/token_position_id': (8, 1, 640, 2),
 'inputs/encoder/spectrogram/token_raw': (8, 1, 32, 20),
 'inputs/encoder/text/token_coordinate': (8, 1, 16),
 'inputs/encoder/text/token_id': (8, 1, 16),
 'inputs/encoder/text/token_mask': (8, 1, 16),
 'inputs/encoder/text/token_position_id': (8, 1, 16),
 'inputs/encoder/vision/token_coordinate': (8, 1, 1, 3),
 'inputs/encoder/vision/token_position_id': (8, 1, 1),
 'inputs/encoder/vision/token_raw': (8, 1, 1, 1572864),
 'inputs/encoder/waveform/token_coordinate': (8, 1, 1),
 'inputs/encoder/waveform/token_position_id': (8, 1, 1),
 'inputs/encoder/waveform/token_raw': (8, 1, 1, 7680),
 'targets/label_classifier/vision/label': (8, 1, 400)}

## Run Model

In [None]:
# You can make changes live and reload your changed modules
# reload_module(imp)

In [None]:
# Build an IMP model from the base image-training experiment config, with the
# batch size and per-modality input sizes overridden for this demo.
config = imp_exp_config.ImpBaseImgTrainExperiment(path="")
config.model.input_batch_size = 8
config.model.vision_input_size = (4, 256, 256, 3)
config.model.waveform_input_size = 1024
config.model.spectrogram_input_size = (64, 64)
model = imp.IMP(**config.model.as_dict())

# initialize the model
_, prng_keys = executors.get_rngs(model.get_rng_keys(), add_params_rngs=True)
inputs = model.get_data_signature()

with mesh:
  params = model.init(prng_keys, inputs, True)
  # Unwrap any boxed values so `params` holds the underlying arrays
  # (see flax.linen.unbox).
  params = nn.unbox(params)

# flatten_dict already returns a fresh flat dict keyed by '/'-joined paths, so
# no extra identity copy is needed.
all_params = flax.traverse_util.flatten_dict(params['params'], sep='/')
# Subset of parameters whose path mentions the transformer encoder.
encoder_params = {
    name: value
    for name, value in all_params.items()
    if 'transformer_encoder' in name
}

# Parameter counts are reported in millions.
print(f"total params:\t{executors.count_params(all_params) / 10 ** 6} M")
print(f"encoder params:\t{executors.count_params(encoder_params) / 10 ** 6} M")

total params:	399.790912 M
encoder params:	84.954624 M


In [None]:
# Print clu's parameter overview table for the initialized parameters.
print(parameter_overview.get_parameter_overview(params["params"]))

+--------------------------------------------------------------------------------------------+-------------------+---------+------------+-----------+--------+
| Name                                                                                       | Shape             | Dtype   | Size       | Mean      | Std    |
+--------------------------------------------------------------------------------------------+-------------------+---------+------------+-----------+--------+
| spectrogram_token_embed_disjoint_projection/to_text/layer_norm/bias                        | (1024,)           | float32 | 1,024      | 0.0       | 0.0    |
| spectrogram_token_embed_disjoint_projection/to_text/layer_norm/scale                       | (1024,)           | float32 | 1,024      | 1.0       | 0.0    |
| spectrogram_token_embed_disjoint_projection/to_text/wi/bias                                | (1024,)           | float32 | 1,024      | 0.0       | 0.0    |
| spectrogram_token_embed_disjoint_projection/

In [None]:
# Perform a sample inference
# Positional arguments 2 and 3 of model.apply are marked static for jit; only
# argument 2 (the boolean flag) is passed positionally below.
model_fn = jax.jit(model.apply, static_argnums=(2, 3))
with mesh:
  # NOTE(review): the positional flag is True in model.init and False here;
  # its meaning is not visible in this notebook -- confirm against IMP.__call__.
  outputs = model_fn(params, inputs, False, rngs=prng_keys)

# Show the shape of every leaf in the model outputs.
jax.tree.map(lambda x: x.shape, outputs)

{'hyperparams': {'encoder': {'temperature': {'spectrogram_spectrogram': (),
    'spectrogram_text': (),
    'spectrogram_vision': (),
    'spectrogram_waveform': (),
    'text_text': (),
    'text_vision': (),
    'text_waveform': (),
    'vision_vision': (),
    'vision_waveform': (),
    'waveform_waveform': ()}}},
 'inputs': {'encoder': {'spectrogram': {'token_coordinate': (8, 1, 16, 2),
    'token_embed': (8, 1, 16, 768),
    'token_id': (8, 1, 16),
    'token_raw': (8, 1, 16, 256)},
   'text': {'token_coordinate': (8, 1, 4096),
    'token_embed': (8, 1, 4096, 768),
    'token_id': (8, 1, 4096),
    'token_mask': (8, 1, 4096)},
   'vision': {'token_coordinate': (8, 1, 256, 3),
    'token_embed': (8, 1, 256, 768),
    'token_id': (8, 1, 256),
    'token_raw': (8, 1, 256, 3072)},
   'waveform': {'token_coordinate': (8, 1, 4),
    'token_embed': (8, 1, 4, 768),
    'token_id': (8, 1, 4),
    'token_raw': (8, 1, 4, 256)}}},
 'outputs': {'common_space': {'spectrogram': {'token_embed': {

## Train & Eval Model

Perform an end-to-end training run.

You can insert breakpoints and inspect arrays using `jax.debug`.
See https://jax.readthedocs.io/en/latest/debugging/index.html

### IMP

In [None]:
# %debug  # Uncomment for debugging errors

# train run

# Short aliases for the constants used by the train and eval configs.
VISION = constants.Modality.VISION
TEXT = constants.Modality.TEXT
ENCODER = constants.DataFeatureRoute.ENCODER
BASE_TRAIN_BATCH_SIZE = 8
SequenceLength = constants.SequenceLength

# Output directory for the run; also used as the TensorBoard logdir and as the
# path of the eval experiment.
model_dir = '/tmp/imp/'
dataset=datasets_config.CONCEPTUAL_CAPTIONS_3M
VISION_INPUT_SIZE = (4, 224, 224, 3)
VISION_PATCH_SIZE = (4, 16, 16)
# Per-modality dataset overrides shared by the train and eval loaders.
_VISION_OVERRIDE = {
    **imp_data_config._vision_override(VISION_INPUT_SIZE),
    'temporal_patch_size': VISION_PATCH_SIZE[0],
}
_TEXT_OVERRIDE = {
    'max_num_tokens': SequenceLength.SMALL,
}
# WARNING: recursively deletes any existing contents of model_dir so the run
# starts from an empty directory.
if tf.io.gfile.exists(model_dir):
  tf.io.gfile.rmtree(model_dir)

# Training experiment: a single training loader over `dataset` with the
# vision/text overrides, a CrossModalNCE loss on the (VISION, TEXT) pair, and
# 10 total optimization steps.
train_config = imp_exp_config.ImpBaseImgTrainExperiment(
    path=model_dir,
    data=imp_exp_config.BasePreTrainExperimentData(
        batch_size=BASE_TRAIN_BATCH_SIZE,
        loaders=(
            data_config.Loader(
                interval=1,
                num_epochs=-1,
                is_training=True,
                batch_size=BASE_TRAIN_BATCH_SIZE,
                microbatch_splits=1,
                metadata=imp_data_config.get_contrastive_metadata(
                    (VISION, TEXT)),
                dataset=dataset.copy_and_override({
                    'modalities': {
                        'vision': _VISION_OVERRIDE,
                        'text': _TEXT_OVERRIDE,
                    },
                }),
            ),
        ),
    ),
    optimization=imp_exp_config.BasePreTrainOptimization(
        total_steps=10,
        loss=(
            opt_config.CrossModalNCE(
                modality_pair_weights=(
                    ((VISION, TEXT), 1.),
                ),
                hparams_route_key=ENCODER,
                dtype=jnp.float32,
            ),
        ),
    ),
    # Model input sizes mirror the data overrides above.
    model=imp_model_config.BaseIMP(
        input_batch_size=BASE_TRAIN_BATCH_SIZE,
        vision_input_size=VISION_INPUT_SIZE,
        vision_patch_size=VISION_PATCH_SIZE,
        text_input_size=SequenceLength.SMALL,
    ),
)

# Build the model and data from the training config and run the executor in
# the mode the config specifies, all inside the device mesh context.
with mesh:
  model = imp.IMP(**train_config.model.as_dict())
  data = dataloader.create_data(train_config.data)
  executor = executors.Executor(model, data, train_config)
  executor.run(train_config.mode)

In [65]:
# Run TensorBoard

%load_ext tensorboard
%tensorboard --logdir=/tmp/imp/ --port=0

In [None]:
# eval run
DataTuning = constants.DataTuning
ServingStrategy = constants.ServingStrategy
TINY_EVAL_BATCH_SIZE = 8

# Evaluate from the same directory the training run wrote to, reusing the
# training model and optimization settings.
eval_config = imp_exp_config.ImpBaseImgEvalExperiment(path=model_dir)
eval_config.model = train_config.model
eval_config.optimization = train_config.optimization

# NOTE: this rebinds the notebook-level `dataset` (previously the training
# dataset). The loader below reads from it, mirroring the train cell, so the
# assignment is no longer dead.
dataset = datasets_config.FLICKR30K
eval_config.data.loaders = (
    data_config.Loader(
        num_epochs=1,
        is_training=False,
        batch_size=TINY_EVAL_BATCH_SIZE,
        microbatch_splits=1,
        prefetch=1,
        tuning=DataTuning.EFFICIENT,
        metadata=imp_data_config.get_contrastive_metadata((VISION, TEXT)),
        dataset=dataset.copy_and_override({
            # Read the 'test' table of the dataset.
            'data': {
                'table': 'test'
            },
            'modalities': {
                'vision': _VISION_OVERRIDE,
                'text': _TEXT_OVERRIDE,
            },
        }),
        serving=ServingStrategy.BULK_ZS_RETRIEVAL,
    ),
)

# Build the model and data from the eval config and run the executor in the
# mode the config specifies, all inside the device mesh context.
with mesh:
  model = imp.IMP(**eval_config.model.as_dict())
  data = dataloader.create_data(eval_config.data)
  executor = executors.Executor(model, data, eval_config)
  executor.run(eval_config.mode)

## Load Checkpoints

### IMP Base

In [100]:
import flax
# A bare `import flax` does not guarantee that the `flax.training.checkpoints`
# submodule has been loaded, so import it explicitly instead of relying on
# some other module having imported it first.
from flax.training import checkpoints as flax_checkpoints

# TODO: make this work with partitioned checkpoints

checkpoint_path = '/path/to/checkpoint'

# With target=None the deserialized state dict is returned as-is.
imp_base_params = flax_checkpoints.restore_checkpoint(
    ckpt_dir=checkpoint_path,
    target=None,
)
# restore_checkpoint hands back the (None) target when no checkpoint is found;
# fail with a clear message instead of a TypeError on the subscript below.
if imp_base_params is None:
  raise FileNotFoundError(f'No checkpoint found at {checkpoint_path!r}')
imp_base_params_flat = flax.traverse_util.flatten_dict(imp_base_params['params'], sep="/")

# Print every parameter path together with its shape.
for k, v in imp_base_params_flat.items():
  print(k, v.shape)

params/audio_raw_to_embed/wav_pos_encoding/layer_norm/bias (768,)
params/audio_raw_to_embed/wav_pos_encoding/layer_norm/scale (768,)
params/audio_raw_to_embed/wav_pos_encoding/temporal_postition_embeddings/embedding (300, 768)
params/audio_raw_to_embed/wav_to_embedding/bias (768,)
params/audio_raw_to_embed/wav_to_embedding/kernel (256, 768)
params/disjoint_projection/audio_to_text/layer_norm/bias (1024,)
params/disjoint_projection/audio_to_text/layer_norm/scale (1024,)
params/disjoint_projection/audio_to_text/wi/bias (1024,)
params/disjoint_projection/audio_to_text/wi/kernel (768, 1024)
params/disjoint_projection/audio_to_text/wo/bias (1024,)
params/disjoint_projection/audio_to_text/wo/kernel (1024, 1024)
params/disjoint_projection/audio_to_vision/layer_norm/bias (1024,)
params/disjoint_projection/audio_to_vision/layer_norm/scale (1024,)
params/disjoint_projection/audio_to_vision/wi/bias (1024,)
params/disjoint_projection/audio_to_vision/wi/kernel (768, 1024)
params/disjoint_projection

### IMP Huge

In [101]:
# TODO: make this work with partitioned checkpoints
# A bare `import flax` does not guarantee that the `flax.training.checkpoints`
# submodule has been loaded, so import it explicitly instead of relying on
# some other module having imported it first.
from flax.training import checkpoints as flax_checkpoints

checkpoint_path = '/path/to/checkpoint'

# With target=None the deserialized state dict is returned as-is.
imp_huge_params = flax_checkpoints.restore_checkpoint(
    ckpt_dir=checkpoint_path,
    target=None,
)
# restore_checkpoint hands back the (None) target when no checkpoint is found;
# fail with a clear message instead of a TypeError on the subscript below.
if imp_huge_params is None:
  raise FileNotFoundError(f'No checkpoint found at {checkpoint_path!r}')

imp_huge_params_flat = flax.traverse_util.flatten_dict(imp_huge_params['params'], sep="/")

# Scan groups together all transformer layers into one, so we split the
# weights into individual layers.
def split_scanned_params(params, scan_axis=0):
  """Splits scanned transformer-encoder weights into per-layer entries.

  Args:
    params: flat mapping from '/'-joined parameter names to arrays (anything
      exposing `.shape` and tuple indexing).
    scan_axis: axis of each scanned array that indexes the layers. Negative
      values count from the last axis.

  Returns:
    A new dict. Entries whose name does not start with
    'params/transformer_encoder/layer_scan/' are carried over unchanged. Every
    other entry is replaced by one
    'params/transformer_encoder/layer_{i}/<suffix>' entry per index `i` along
    `scan_axis`, holding that slice of the array.
  """
  # The trailing slash keeps look-alike names such as '.../layer_scan_extra/w'
  # from being treated as scanned weights.
  scan_prefix = 'params/transformer_encoder/layer_scan/'
  output_params = {}

  for name, param in params.items():
    if not name.startswith(scan_prefix):
      output_params[name] = param
      continue

    suffix = name[len(scan_prefix):]
    # Normalize negative axes so the leading full slices line up with the axis
    # whose length is being iterated.
    axis = scan_axis if scan_axis >= 0 else scan_axis + len(param.shape)
    leading = (slice(None),) * axis
    for i in range(param.shape[axis]):
      layer_name = f'params/transformer_encoder/layer_{i}/{suffix}'
      output_params[layer_name] = param[leading + (i,)]
  return output_params

# Expand the scanned encoder weights into per-layer entries (default scan
# axis 0), then print every parameter path together with its shape.
imp_huge_params_split_flat = split_scanned_params(imp_huge_params_flat)

for k, v in imp_huge_params_split_flat.items():
  print(k, v.shape)

params/audio_raw_to_embed/wav_pos_encoding/layer_norm/bias (1536,)
params/audio_raw_to_embed/wav_pos_encoding/layer_norm/scale (1536,)
params/audio_raw_to_embed/wav_pos_encoding/temporal_postition_embeddings/embedding (600, 1536)
params/audio_raw_to_embed/wav_to_embedding/bias (1536,)
params/audio_raw_to_embed/wav_to_embedding/kernel (256, 1536)
params/disjoint_projection/audio_to_text/layer_norm/bias (1024,)
params/disjoint_projection/audio_to_text/layer_norm/scale (1024,)
params/disjoint_projection/audio_to_text/wi/bias (1024,)
params/disjoint_projection/audio_to_text/wi/kernel (1536, 1024)
params/disjoint_projection/audio_to_text/wo/bias (1024,)
params/disjoint_projection/audio_to_text/wo/kernel (1024, 1024)
params/disjoint_projection/audio_to_vision/layer_norm/bias (1024,)
params/disjoint_projection/audio_to_vision/layer_norm/scale (1024,)
params/disjoint_projection/audio_to_vision/wi/bias (1024,)
params/disjoint_projection/audio_to_vision/wi/kernel (1536, 1024)
params/disjoint_pro