In [1]:
import utils

2025-05-02 01:53:28.966013: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-02 01:53:28.966079: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-02 01:53:28.966105: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


# Config

In [2]:
# Load the project's LOCA config and override the batch size to 1 for this notebook.
from configs.loca_config import get_config

config = get_config()
config.batch_size = 1

# Dataset

In [3]:
import os

# Configure the environment *before* importing the libraries below, so the
# settings are already in place when those libraries are loaded/initialised.
# NOTE(review): earlier cells may already have imported some of these libraries;
# confirm that both variables actually take effect in a fresh kernel.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.75'
# NOTE(review): hard-coded absolute, user-specific path; consider making it configurable.
os.environ['TFDS_DATA_DIR'] = '/home/admin/john/data/tensorflow_datasets'

from scenic.train_lib import train_utils
import jax.numpy as jnp
# `loca_dataset` and `ops` are not referenced by name in this cell; presumably they
# are imported for registration side effects — confirm before removing.
import loca_dataset
import ops
import jax

# Fixed seed; one half of the split drives the dataset, the other is kept in `rng`.
rng = jax.random.key(42)
data_rng, rng = jax.random.split(rng)

# Build the dataset from the notebook-level `config`.
dataset = train_utils.get_dataset(config, data_rng)


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 

2025-05-02 01:53:35.794154: W tensorflow/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".
Instructions for updating:
Use `tf.data.Dataset.counter(...)` instead.
Instructions for updating:
Use `tf.data.D

In [4]:
# Utility
def remove_batch_dim0(batch, debug=False):
    """Squeeze axis 0 out of every entry of `batch`, in place.

    Each value is replaced by `value.squeeze(0)`; the very same mapping object
    is returned. (It is not known where the extra leading dimension is added.)

    Args:
        batch: mutable mapping whose values expose `.squeeze(axis)` and `.shape`.
        debug: when true, print the resulting shape of every entry.

    Returns:
        The `batch` object that was passed in, with updated values.
    """
    # Iterate over a snapshot of the items; only existing keys are reassigned.
    for k, value in tuple(batch.items()):
        batch[k] = value.squeeze(0)
        if not debug:
            continue
        print(f"batch[{k}]: {batch[k].shape}")

    return batch

In [5]:
# Pull one training batch and strip its leading axis, then hand it to the
# project-side input preparation together with the current `config`.
batched_sample = remove_batch_dim0(next(dataset.train_iter))
batched_sample = utils.prepare_input(batched_sample, config)

In [6]:
# Display the shape of the prepared reference input.
batched_sample['reference'].shape

(1, 224, 224, 21)

# Model

In [7]:
import flax
from flax.training import checkpoints
import optax
from scenic.train_lib import lr_schedules
import vit
import copy
from scenic.common_lib import debug_utils
import functools

# NOTE(review): `train_state` is not referenced anywhere else in the visible
# notebook — confirm whether this placeholder is still needed.
train_state = None

def compute_loca_flops():
    """Print FLOP estimates for the three LOCA forward passes.

    Builds a `vit.ViTLOCAModel` from the notebook-level `config` and
    `dataset.meta_data`, initialises it, and measures with
    `debug_utils.compute_flops`:

    * ref:    the reference pass, using the shape of `batched_sample['reference']`,
    * q_rand: the query pass, using the shape of `batched_sample['query0']`,
    * q_foc:  the query pass over `batched_sample['queries']`, with the
      reference patch features tiled `number_of_focal_queries` times.

    Each value and their sum are printed divided by 10**9. Nothing is returned.

    Note: `config`, `dataset` and `batched_sample` are read from the notebook
    namespace at call time. Rebinding `config` in a later cell changes the
    model built here, but this function does not rebuild `dataset` or
    `batched_sample`.
    """
    use_pe = bool(config.apply_cluster_loss)
    n_q_foc = config.dataset_configs.number_of_focal_queries

    # Derive every key from one split so that no PRNG key is consumed twice
    # (neither by a second split nor by two different rng streams).
    (init_params_rng, init_changroup_rng,
     dropout_rng, droptok_rng, changroup_rng) = jax.random.split(
         jax.random.key(42), num=5)
    apply_rngs = {
        'dropout': dropout_rng,
        'droptok': droptok_rng,
        'changroup': changroup_rng,
    }

    model = vit.ViTLOCAModel(config, dataset.meta_data)

    # The last two returned values (parameter count, gflops) are not used here.
    params, state, _, _ = train_utils.initialize_model(
        model_def=model.flax_model,
        input_spec=[(
            dataset.meta_data['input_shape'],
            dataset.meta_data.get('input_dtype', jnp.float32),
        )],
        config=config,
        rngs={'params': init_params_rng, 'changroup': init_changroup_rng},
    )
    variables = {'params': params, **state}

    # Keyword arguments shared by the measured and the concrete reference pass.
    ref_kwargs = dict(
        seqlen=config.reference_seqlen,
        seqlen_selection=config.reference_seqlen_selection,
        drop_moment='late',
    )

    # flops for ref pass
    r_flops = debug_utils.compute_flops(
        flax_model_apply_fn=functools.partial(
            model.flax_model.apply,
            variables,
            train=False,
            debug=False,
            rngs=apply_rngs,
            **ref_kwargs,
        ),
        input_spec=[(batched_sample['reference'].shape, jnp.float32)],
        fuse_multiply_add=True,  # Default
    )

    # Get inputs for query passes: run the reference pass for real, with the
    # same `variables` as the measured passes. Only the patch features and the
    # kept-group indices are needed from its outputs.
    _, _, r_patch_features, _, _, r_idx_kept_groups = model.flax_model.apply(
        variables,
        batched_sample['reference'],
        train=False,
        rngs=apply_rngs,
        **ref_kwargs,
    )

    # flops for q-rand pass
    q_rand_flops = debug_utils.compute_flops(
        flax_model_apply_fn=functools.partial(
            model.flax_model.apply,
            variables,
            inputs_kv=r_patch_features,
            inputs_kv_kept_groups=r_idx_kept_groups,
            seqlen=config.query_max_seqlen,
            use_pe=use_pe,
            train=False,
            debug=False,
            rngs=apply_rngs,
        ),
        input_spec=[(batched_sample['query0'].shape, jnp.float32)],
        fuse_multiply_add=True,  # Default
    )

    # flops for q-foc pass
    def model_wrapper(x):
        # The dummy input handed over by `compute_flops` is ignored on purpose;
        # the real focal-query batch is fed instead.
        # NOTE(review): presumably because the dummy input would not match the
        # batch dimension of the tiled `inputs_kv` — confirm.
        del x
        return model.flax_model.apply(
            variables,
            batched_sample['queries'],
            inputs_kv=jnp.tile(r_patch_features, (n_q_foc, 1, 1)),
            inputs_kv_kept_groups=(
                None if r_idx_kept_groups is None
                else jnp.tile(r_idx_kept_groups, (n_q_foc, 1))),
            use_pe=use_pe,
            train=False,
            debug=False,
            rngs=apply_rngs,
        )

    q_foc_flops = debug_utils.compute_flops(
        flax_model_apply_fn=model_wrapper,
        input_spec=[(batched_sample['queries'].shape, jnp.float32)],
        fuse_multiply_add=True,  # Default
    )

    print(f'ref: {r_flops / 10**9:.3f}')
    print(f'q_rand: {q_rand_flops / 10**9:.3f}')
    print(f'q_foc: {q_foc_flops / 10**9:.3f}')

    print(f'total: {(r_flops + q_rand_flops + q_foc_flops) / 10**9:.3f}')



# Experiments

## LOCA Vanilla

In [10]:
# Baseline experiment: fresh config, no multimodal input, channel grouping disabled.
from configs.loca_config import get_config
config = get_config()
config.batch_size = 1

# Options listed for `multimodal`: None, 'early_fuse_s1_to_rgbn', 'early_fuse_s1_to_all',
# 'early_concat_s2_and_s1', 'early_concat_s2_and_s1_early_fuse_dem', 'early_concat_s2_s1_dem'
config.multimodal = None
config.use_same_group_attn_mask = False
config.sen2grouped = False
config.sen2grouped_maintain_seqlen = False

# Geometry of the reference view and of the random / focal query views.
patch = 16
reference_resolution = 224
reference_patch_width = reference_resolution // patch
query_rand_res = reference_resolution
query_rand_mask_res = query_rand_res // patch  # Should be equal to patch width/height of rand query
query_foc_res = 96
query_foc_mask_res = query_foc_res // patch  # Should be equal to patch width/height of focal query
n_queries = 10

# The preprocessing pipeline is the concatenation of the fragments below.
# NOTE(review): `dataset` and `batched_sample` were created by earlier cells and are
# not rebuilt here — confirm that this `pp_train` override influences the measurement.
config.dataset_configs.pp_train = ''.join([
    # Sentinel2 preprocessing.
    'permute_channels_last("sentinel2")',
    '|copy("sentinel2", "reference")',
    f'|init_patch_matching_tracker({reference_patch_width}, "target_mask")',
    '|init_box_tracker("target_box")',
    f'|cropflip_generatemask({reference_resolution}, 32, flip=False, inkey=("reference", "target_mask", "target_box"), outkey=("reference", "target_mask", "target_box"))',
    # One copy of the input per query view.
    *[f'|copy("sentinel2", "query{i}")' for i in range(n_queries)],
    # query0 is cropped at the random-query resolution ...
    f'|inception_crop_with_mask(({query_rand_res}, {query_rand_res}), 32, 100, ({query_rand_mask_res}, {query_rand_mask_res}), inkey=("query0", "target_mask", "target_box"), outkey=("query0", "query0_mask", "query0_box"))',
    # ... the remaining queries at the focal-query resolution.
    *[
        f'|inception_crop_with_mask(({query_foc_res}, {query_foc_res}), 5, 32, ({query_foc_mask_res}, {query_foc_mask_res}), inkey=("query{i}", "target_mask", "target_box"), outkey=("query{i}", "query{i}_mask", "query{i}_box"))'
        for i in range(1, n_queries)
    ],
    *[
        f'|flip_with_mask(inkey=("query{i}", "query{i}_mask"), outkey=("query{i}", "query{i}_mask"))'
        for i in range(n_queries)
    ],
    '|keep("reference"',
    *[f', "query{i}", "query{i}_box", "query{i}_mask"' for i in range(n_queries)],
    ', "is_l2a")',
])


compute_loca_flops()

ref: 0.239
q_rand: 0.044
q_foc: 0.719
total: 1.002


## With grouping

In [21]:
# Channel-grouping experiment: fresh config with `sen2grouped` enabled and
# `sen2grouped_maintain_seqlen` disabled.
from configs.loca_config import get_config
config = get_config()
config.batch_size = 1

# Options listed for `multimodal`: None, 'early_fuse_s1_to_rgbn', 'early_fuse_s1_to_all',
# 'early_concat_s2_and_s1', 'early_concat_s2_and_s1_early_fuse_dem', 'early_concat_s2_s1_dem'
config.multimodal = None
config.use_same_group_attn_mask = False
config.sen2grouped = True
config.sen2grouped_maintain_seqlen = False
config.sen2changroups = ((1, 2, 3, 7), (4, 5, 6, 8), (10, 11))

# Geometry of the reference view and of the random / focal query views.
patch = 16
reference_resolution = 224
reference_patch_width = reference_resolution // patch
query_rand_res = reference_resolution
query_rand_mask_res = query_rand_res // patch  # Should be equal to patch width/height of rand query
query_foc_res = 96
query_foc_mask_res = query_foc_res // patch  # Should be equal to patch width/height of focal query
n_queries = 10

# The preprocessing pipeline is the concatenation of the fragments below.
# NOTE(review): `dataset` and `batched_sample` were created by earlier cells and are
# not rebuilt here — confirm that this `pp_train` override influences the measurement.
config.dataset_configs.pp_train = ''.join([
    # Sentinel2 preprocessing.
    'permute_channels_last("sentinel2")',
    '|copy("sentinel2", "reference")',
    f'|init_patch_matching_tracker({reference_patch_width}, "target_mask")',
    '|init_box_tracker("target_box")',
    f'|cropflip_generatemask({reference_resolution}, 32, flip=False, inkey=("reference", "target_mask", "target_box"), outkey=("reference", "target_mask", "target_box"))',
    # One copy of the input per query view.
    *[f'|copy("sentinel2", "query{i}")' for i in range(n_queries)],
    # query0 is cropped at the random-query resolution ...
    f'|inception_crop_with_mask(({query_rand_res}, {query_rand_res}), 32, 100, ({query_rand_mask_res}, {query_rand_mask_res}), inkey=("query0", "target_mask", "target_box"), outkey=("query0", "query0_mask", "query0_box"))',
    # ... the remaining queries at the focal-query resolution.
    *[
        f'|inception_crop_with_mask(({query_foc_res}, {query_foc_res}), 5, 32, ({query_foc_mask_res}, {query_foc_mask_res}), inkey=("query{i}", "target_mask", "target_box"), outkey=("query{i}", "query{i}_mask", "query{i}_box"))'
        for i in range(1, n_queries)
    ],
    *[
        f'|flip_with_mask(inkey=("query{i}", "query{i}_mask"), outkey=("query{i}", "query{i}_mask"))'
        for i in range(n_queries)
    ],
    '|keep("reference"',
    *[f', "query{i}", "query{i}_box", "query{i}_mask"' for i in range(n_queries)],
    ', "is_l2a")',
])

compute_loca_flops()

ref: 1.814
q_rand: 0.271
q_foc: 2.163
total: 4.247


## Channel grouping and Group sampling

In [22]:
# Channel grouping plus group sampling: fresh config with both `sen2grouped`
# and `sen2grouped_maintain_seqlen` enabled.
from configs.loca_config import get_config
config = get_config()
config.batch_size = 1

# Options listed for `multimodal`: None, 'early_fuse_s1_to_rgbn', 'early_fuse_s1_to_all',
# 'early_concat_s2_and_s1', 'early_concat_s2_and_s1_early_fuse_dem', 'early_concat_s2_s1_dem'
config.multimodal = None
config.use_same_group_attn_mask = False
config.sen2grouped = True
config.sen2grouped_maintain_seqlen = True
config.sen2changroups = ((1, 2, 3, 7), (4, 5, 6, 8), (10, 11))

# Geometry of the reference view and of the random / focal query views.
patch = 16
reference_resolution = 224
reference_patch_width = reference_resolution // patch
query_rand_res = reference_resolution
query_rand_mask_res = query_rand_res // patch  # Should be equal to patch width/height of rand query
query_foc_res = 96
query_foc_mask_res = query_foc_res // patch  # Should be equal to patch width/height of focal query
n_queries = 10

# The preprocessing pipeline is the concatenation of the fragments below.
# NOTE(review): `dataset` and `batched_sample` were created by earlier cells and are
# not rebuilt here — confirm that this `pp_train` override influences the measurement.
config.dataset_configs.pp_train = ''.join([
    # Sentinel2 preprocessing.
    'permute_channels_last("sentinel2")',
    '|copy("sentinel2", "reference")',
    f'|init_patch_matching_tracker({reference_patch_width}, "target_mask")',
    '|init_box_tracker("target_box")',
    f'|cropflip_generatemask({reference_resolution}, 32, flip=False, inkey=("reference", "target_mask", "target_box"), outkey=("reference", "target_mask", "target_box"))',
    # One copy of the input per query view.
    *[f'|copy("sentinel2", "query{i}")' for i in range(n_queries)],
    # query0 is cropped at the random-query resolution ...
    f'|inception_crop_with_mask(({query_rand_res}, {query_rand_res}), 32, 100, ({query_rand_mask_res}, {query_rand_mask_res}), inkey=("query0", "target_mask", "target_box"), outkey=("query0", "query0_mask", "query0_box"))',
    # ... the remaining queries at the focal-query resolution.
    *[
        f'|inception_crop_with_mask(({query_foc_res}, {query_foc_res}), 5, 32, ({query_foc_mask_res}, {query_foc_mask_res}), inkey=("query{i}", "target_mask", "target_box"), outkey=("query{i}", "query{i}_mask", "query{i}_box"))'
        for i in range(1, n_queries)
    ],
    *[
        f'|flip_with_mask(inkey=("query{i}", "query{i}_mask"), outkey=("query{i}", "query{i}_mask"))'
        for i in range(n_queries)
    ],
    '|keep("reference"',
    *[f', "query{i}", "query{i}_box", "query{i}_mask"' for i in range(n_queries)],
    ', "is_l2a")',
])

compute_loca_flops()

ref: 0.240
q_rand: 0.044
q_foc: 0.719
total: 1.003


## Sen2 + Sen1 early summation

In [10]:
# Fresh config for the 'early_fuse_s1_to_rgbn' multimodal setting, batch size 1.
from configs.loca_config import get_config
config = get_config()
config.batch_size = 1

config.multimodal = 'early_fuse_s1_to_rgbn'
config.sen2changroups = ((1, 2, 3, 7), (4, 5, 6, 8), (10, 11))
config.sen2grouped_maintain_seqlen = False
config.use_same_group_attn_mask = False

# Print the FLOP breakdown for the config set above.
# NOTE(review): the saved output of this cell contains interactive `jdb` prompts, so a
# debugger breakpoint was presumably active in the model code when it ran — confirm
# that it has been removed.
compute_loca_flops()

Entering jdb:


(jdb)  x.shape


(1, 588, 384)


(jdb)  c


Entering jdb:


(jdb)  c


Entering jdb:


(jdb)  c


Entering jdb:


(jdb)  c


ref: 1.814
q_rand: 0.271
q_foc: 2.163
total: 4.247


## Sen2 + Sen1 early concat

In [8]:
# Fresh config for the 'early_concat_s2_and_s1' multimodal setting, batch size 1,
# without maintained sequence length and without same-group attention masking.
from configs.loca_config import get_config
config = get_config()
config.batch_size = 1

config.multimodal = 'early_concat_s2_and_s1'
config.sen2grouped_maintain_seqlen = False
config.use_same_group_attn_mask = False
# Five channel groups: the first three are marked as sen2 by the inline comment;
# the last two are presumably the Sentinel-1 groups — confirm.
config.sen2changroups = ((1, 2, 3, 7), (4, 5, 6, 8), (10, 11),# sen2
                           (12, 16), (13, 17))

# Print the FLOP breakdown for the config set above.
compute_loca_flops()

ref: 4.848
q_rand: 0.684
q_foc: 3.622
total: 9.154


## Sen2 + Sen1 early concat with group sampling

In [10]:
# Fresh config for 'early_concat_s2_and_s1' with `sen2grouped_maintain_seqlen`
# enabled and same-group attention masking disabled, batch size 1.
from configs.loca_config import get_config
config = get_config()
config.batch_size = 1

config.multimodal = 'early_concat_s2_and_s1'
config.sen2grouped_maintain_seqlen = True
config.use_same_group_attn_mask = False
# Five channel groups: the first three are marked as sen2 by the inline comment;
# the last two are presumably the Sentinel-1 groups — confirm.
config.sen2changroups = ((1, 2, 3, 7), (4, 5, 6, 8), (10, 11),# sen2
                           (12, 16), (13, 17))
# Five weights for the five channel groups above (presumably one per group, in order — confirm).
config.changroups_sampling_weights = (2, 2, 2, 3, 3)

# Print the FLOP breakdown for the config set above.
compute_loca_flops()

ref: 0.240
q_rand: 0.044
q_foc: 0.720
total: 1.004


## Sen2 + Sen1 early concat with group sampling and same-group attention masking

In [11]:
# Fresh config for 'early_concat_s2_and_s1' with both `sen2grouped_maintain_seqlen`
# and `use_same_group_attn_mask` enabled, batch size 1.
from configs.loca_config import get_config
config = get_config()
config.batch_size = 1

config.multimodal = 'early_concat_s2_and_s1'
config.sen2grouped_maintain_seqlen = True
config.use_same_group_attn_mask = True
# Five channel groups: the first three are marked as sen2 by the inline comment;
# the last two are presumably the Sentinel-1 groups — confirm.
config.sen2changroups = ((1, 2, 3, 7), (4, 5, 6, 8), (10, 11),# sen2
                           (12, 16), (13, 17))
# Five weights for the five channel groups above (presumably one per group, in order — confirm).
config.changroups_sampling_weights = (2, 2, 2, 3, 3)

# Print the FLOP breakdown for the config set above.
compute_loca_flops()

ref: 0.242
q_rand: 0.044
q_foc: 0.721
total: 1.007


## Sen2 + Sen1 + DEM early concat

In [19]:
# Fresh config for the 'early_concat_s2_s1_dem' multimodal setting with both
# `sen2grouped_maintain_seqlen` and `use_same_group_attn_mask` enabled, batch size 1.
from configs.loca_config import get_config
config = get_config()
config.batch_size = 1

config.multimodal = 'early_concat_s2_s1_dem'
config.sen2grouped_maintain_seqlen = True
config.use_same_group_attn_mask = True
# Six channel groups: the first three are marked as sen2 by the inline comment; the
# next two are presumably Sentinel-1 and the single-channel group the DEM — confirm.
config.sen2changroups = ((1, 2, 3, 7), (4, 5, 6, 8), (10, 11),# sen2
                           (12, 16), (13, 17),
                        (20,))
# Six weights for the six channel groups above (presumably one per group, in order — confirm).
config.changroups_sampling_weights = (2, 2, 2, 3, 3, 6)

# Print the FLOP breakdown for the config set above.
compute_loca_flops()

ref: 0.242
q_rand: 0.044
q_foc: 0.721
total: 1.007


## Sen2 + Sen1 + DEM early concat ref masking 60%

In [18]:
# Same setting as the 'early_concat_s2_s1_dem' experiment, plus an overridden
# `reference_seqlen`. `get_config` is not re-imported here; it must already be in
# the namespace from an earlier cell.
config = get_config()
config.batch_size = 1

config.multimodal = 'early_concat_s2_s1_dem'
config.sen2grouped_maintain_seqlen = True
config.use_same_group_attn_mask = True
# Six channel groups: the first three are marked as sen2 by the inline comment; the
# next two are presumably Sentinel-1 and the single-channel group the DEM — confirm.
config.sen2changroups = ((1, 2, 3, 7), (4, 5, 6, 8), (10, 11),# sen2
                           (12, 16), (13, 17),
                        (20,))
# Six weights for the six channel groups above (presumably one per group, in order — confirm).
config.changroups_sampling_weights = (2, 2, 2, 3, 3, 6)
# NOTE(review): the section heading says "ref masking 60%", but the factor applied
# to `n_ref_positions` here is 0.2 — confirm which value is intended.
config.reference_seqlen = int(0.2 * config.n_ref_positions)

# Print the FLOP breakdown for the config set above.
# NOTE(review): the saved output of this cell is identical to that of the preceding
# experiment — confirm that the `reference_seqlen` override takes effect in the
# measured passes (which run with train=False).
compute_loca_flops()

ref: 0.242
q_rand: 0.044
q_foc: 0.721
total: 1.007
