In [None]:
# @title Imports
# import pytest

import functools
import json

from flax import jax_utils
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import ml_collections
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds

from google3.pyglib import build_data
from google3.pyglib import gfile
# build_data.Changelist()

from colabtools import adhoc_import
# This Google3 import block is required; without it, the later adhoc_import.Google3CitcClient imports break.
with adhoc_import.Google3():
  from scenic.model_lib.base_models import base_model
  from scenic.projects.multimask.models import model_utils as mm_model_utils
  from scenic.model_lib.base_models import model_utils

# Touching tf.data.Dataset here is required (reason unknown); without it, the later adhoc_import.Google3CitcClient imports break.
tf.data.Dataset

# mask_to_patchmask test

In [None]:
# Load `mask_to_patchmask` from the `dataset_utils` module as seen through the
# CitC client named below, so the notebook exercises that workspace's version.
from importlib import reload
from colabtools import adhoc_import

# behavior='preferred' presumably makes the CitC client's sources win over the
# default ones -- TODO confirm against the adhoc_import documentation.
with adhoc_import.Google3CitcClient(
    'lsm_maskoncpu_25_2_16', username='xumax', behavior='preferred'
):
  from google3.experimental.largesensormodels.scenic.datasets import dataset_utils
  # Reload so that re-running this cell re-executes the module instead of
  # reusing whatever copy is already cached in the kernel.
  dataset_utils = reload(dataset_utils)
  # Bind the function under test at notebook scope for the cells below.
  mask_to_patchmask = dataset_utils.mask_to_patchmask

In [None]:
# mechanism='absolute': the two fully-set 2x2 patches (top-left, bottom-right)
# are expected to map to 1 and the two empty patches to 0.
input_size = (4, 4, 1)
patch_size = (2, 2)

# 4x4 pixel-level mask; a trailing channel axis is appended to match input_size.
mask_2d = jnp.array([
    [1, 1, 0, 0],
    [1, 1, 0, 0],
    [0, 0, 1, 1],
    [0, 0, 1, 1],
])
mask = jnp.expand_dims(mask_2d, axis=-1)

# One entry per 2x2 patch.
expected_output = jnp.array([[1, 0], [0, 1]], dtype=jnp.float32)

output = mask_to_patchmask(
    mask,
    input_size,
    patch_size,
    mechanism='absolute',
)

patchmask_matches = jnp.array_equal(output, expected_output)
assert patchmask_matches, f"Expected {expected_output}, got {output}"


In [None]:
# mechanism='1threshold' with thresh_amt=0.5: the top-left and bottom-right
# 2x2 patches each have 3 of 4 pixels set and are expected to map to 1; the
# two empty patches are expected to map to 0.
input_size = (4, 4, 1)
patch_size = (2, 2)

# 4x4 pixel-level mask; a trailing channel axis is appended to match input_size.
mask_2d = jnp.array([
    [1, 1, 0, 0],
    [1, 0, 0, 0],
    [0, 0, 1, 1],
    [0, 0, 1, 0],
])
mask = jnp.expand_dims(mask_2d, axis=-1)

# One entry per 2x2 patch.
expected_output = jnp.array([[1, 0], [0, 1]], dtype=jnp.float32)

output = mask_to_patchmask(
    mask,
    input_size,
    patch_size,
    mechanism='1threshold',
    thresh_amt=0.5,
)

patchmask_matches = jnp.array_equal(output, expected_output)
assert patchmask_matches, f"Expected {expected_output}, got {output}"

In [None]:
# An unrecognised `mechanism` string is expected to be rejected with ValueError.
mask = jnp.array([
    [1, 1, 0, 0],
    [1, 0, 0, 0],
    [0, 0, 1, 1],
    [0, 0, 1, 0]
]).reshape(4, 4, 1)

input_size = (4, 4, 1)
patch_size = (2, 2)

try:
  mask_to_patchmask(
      mask, input_size, patch_size, mechanism='fakemechanism', thresh_amt=0.5
  )
except ValueError:
  # Expected: the unknown mechanism was rejected.
  pass
else:
  # Raised outside the `try` body so the failure can never be caught by the
  # handler above, and with a message so the failure is self-explanatory.
  raise AssertionError(
      "mask_to_patchmask did not raise ValueError for mechanism='fakemechanism'"
  )

# get_random_mask_afterinputmask_indices test

In [None]:
# Load `get_random_mask_afterinputmask_indices` from the `dataset_utils`
# module as seen through the CitC client named below.
from importlib import reload
from colabtools import adhoc_import

# NOTE(review): this client ('lsm3_bettermae_25_2_12') differs from the one
# used when loading mask_to_patchmask ('lsm_maskoncpu_25_2_16'), and this cell
# rebinds the notebook-level name `dataset_utils` -- confirm that mixing the
# two workspaces in one kernel is intended.
with adhoc_import.Google3CitcClient(
    'lsm3_bettermae_25_2_12', username='xumax', behavior='preferred'
):
  from google3.experimental.largesensormodels.scenic.datasets import dataset_utils
  # Reload so that re-running this cell re-executes the module instead of
  # reusing whatever copy is already cached in the kernel.
  dataset_utils = reload(dataset_utils)
  # Bind the function under test at notebook scope for the cells below.
  get_random_mask_afterinputmask_indices = dataset_utils.get_random_mask_afterinputmask_indices

In [None]:
# Two of the ten tokens are already masked on input; the returned mask is
# still expected to cover exactly `n_masked` tokens in total.
n_tokens, n_masked, seed = 10, 5, 42
existing_mask = tf.constant(
    [0, 1, 0, 1, 0, 0, 0, 0, 0, 0],
    dtype=tf.int32,
)

mask_inds, unmasked_inds, binary_mask = get_random_mask_afterinputmask_indices(
    n_tokens, n_masked, existing_mask, seed
)

# Counts: the binary mask and both index lists must agree with the request.
n_unmasked = n_tokens - n_masked
masked_total = tf.reduce_sum(binary_mask).numpy()
assert masked_total == n_masked
assert len(mask_inds.numpy()) == n_masked
assert len(unmasked_inds.numpy()) == n_unmasked

In [None]:
# No token is masked on input; the returned mask is expected to cover exactly
# `n_masked` tokens, and the two index lists must not share any index.
n_tokens, n_masked, seed = 8, 4, 42
existing_mask = tf.zeros([n_tokens], dtype=tf.int32)

mask_inds, unmasked_inds, binary_mask = get_random_mask_afterinputmask_indices(
    n_tokens, n_masked, existing_mask, seed
)

# Counts: the binary mask and both index lists must agree with the request.
n_unmasked = n_tokens - n_masked
masked_total = tf.reduce_sum(binary_mask).numpy()
assert masked_total == n_masked
assert len(mask_inds.numpy()) == n_masked
assert len(unmasked_inds.numpy()) == n_unmasked

# Masked and unmasked indices must be disjoint.
shared_inds = set(mask_inds.numpy()) & set(unmasked_inds.numpy())
assert not shared_inds

In [None]:
# Three tokens are already masked on input but only two masked tokens are
# requested; the call is expected to fail.
n_tokens = 6
n_masked = 2
existing_mask = tf.constant([1, 1, 1, 0, 0, 0], dtype=tf.int32)
seed = 42

try:
  get_random_mask_afterinputmask_indices(n_tokens, n_masked, existing_mask, seed)
except Exception:  # pylint: disable=broad-except
  # Expected: the inconsistent request was rejected.
  # NOTE(review): the concrete exception type is not visible from this
  # notebook, hence the broad catch -- narrow it once the type is confirmed.
  pass
else:
  # Raised from the `else` branch on purpose: an `assert False` inside the
  # `try` body would itself be caught by the handler, so the check could
  # never fail.
  raise AssertionError(
      "get_random_mask_afterinputmask_indices did not raise when n_masked "
      "is smaller than the number of already-masked tokens"
  )

# mask_example test

In [None]:
# Load `mask_example` from the `dataset_utils` module as seen through the
# CitC client named below.
from importlib import reload
from colabtools import adhoc_import

# behavior='preferred' presumably makes the CitC client's sources win over the
# default ones -- TODO confirm against the adhoc_import documentation.
with adhoc_import.Google3CitcClient(
    'lsm3_bettermae_25_2_12', username='xumax', behavior='preferred'
):
  from google3.experimental.largesensormodels.scenic.datasets import dataset_utils
  # Reload so that re-running this cell re-executes the module instead of
  # reusing whatever copy is already cached in the kernel.
  dataset_utils = reload(dataset_utils)
  # Bind the function under test at notebook scope for the cells below.
  mask_example = dataset_utils.mask_example

In [None]:
# A 3x4x1 input with (1, 1) patches has 12 entries in its imputation mask,
# three of which are already set. The "randomonexistmask_0.5" config is
# expected to yield a token mask summing to 6 (presumably 0.5 x 12 tokens --
# TODO confirm against mask_example).
input_size = (3, 4, 1)
patch_size = (1, 1)
seed = 42
masking_configs = "randomonexistmask_0.5"

imputation_mask = tf.constant(
    [
        [[0], [1], [0], [0]],
        [[1], [0], [0], [0]],
        [[0], [0], [1], [0]],
    ],
    dtype=tf.int32,
)
example = {"imputation_mask": imputation_mask}

masked_example = mask_example(
    example,
    masking_configs,
    seed=seed,
    patch_size=patch_size,
    input_size=input_size,
)

# The masking step must add all three bookkeeping entries to the example.
expected_keys = ("mask_indices", "unmasked_indices", "token_mask")
assert all(key in masked_example for key in expected_keys)

n_masked_tokens = tf.reduce_sum(masked_example["token_mask"]).numpy()
assert n_masked_tokens == 6