In [1]:
import json
import os
from pathlib import Path

# Root directory of the MMEarth files. Defaults to the original location;
# set the MMEARTH_DATA_ROOT environment variable to point elsewhere.
data_root = Path(os.environ.get("MMEARTH_DATA_ROOT", '/home/admin/john/data/mmearth'))

# Train-split indices: row numbers into the HDF5 datasets used below.
splits_path = data_root / "data_1M_v001_64_splits.json"
with open(splits_path, "r") as f:  # `with` closes the handle promptly
    indices = json.load(f)["train"]
len(indices)

1240526

In [2]:
import h5py

# Open the MMEarth HDF5 file read-only; the handle is left open so later
# cells can index into `data_full`.
data_path = data_root.joinpath('data_1M_v001_64.h5')
data_full = h5py.File(data_path, mode='r')

In [3]:
# Side files stored next to the HDF5 file: per-tile metadata (S2_type is
# read below) and per-band mean/std statistics used for normalization.
tile_info_path = data_root / 'data_1M_v001_64_tile_info.json'
band_stats_path = data_root / 'data_1M_v001_64_band_stats.json'
tile_info, band_stats = (
    json.loads(path.read_text()) for path in (tile_info_path, band_stats_path)
)

In [18]:
# Peek at the first train example: its tile name and the raw sentinel2 dtype.
idx = 0
sample_row = indices[idx]
name = data_full['metadata'][sample_row][0].decode("utf-8")
data_full['sentinel2'][sample_row, :, ...].dtype

dtype('uint16')

In [8]:
# The band statistics double as the normalization stats (same dict object, not a copy).
norm_stats = band_stats

In [9]:
# Band selection for training: every Sentinel-2 band except B10, plus all
# Sentinel-1 channels ("all" = every entry of MODALITIES_FULL["sentinel1"]).
MODALITIES = dict(
    sentinel2="B1 B2 B3 B4 B5 B6 B7 B8A B8 B9 B11 B12".split(),
    sentinel1="all",
)

# Every band of every modality. A band's position in its list is the index
# used along the band axis of data_full[modality] (band names are turned
# into indices with list.index()).
MODALITIES_FULL = dict(
    sentinel2="B1 B2 B3 B4 B5 B6 B7 B8A B8 B9 B10 B11 B12".split(),
    sentinel2_cloudmask=["QA60"],
    sentinel2_cloudprod=["MSK_CLDPRB"],
    sentinel2_scl=["SCL"],
    sentinel1=[f"{orbit}_{pol}" for orbit in ("asc", "desc") for pol in ("VV", "VH", "HH", "HV")],
    aster=["elevation", "slope"],
    era5=[
        f"{period}_{stat}"
        for period in ("prev_month", "curr_month", "year")
        for stat in ("avg_temp", "min_temp", "max_temp", "total_precip")
    ],
    dynamic_world=["landcover"],
    canopy_height_eth=["height", "std"],
    lat=["sin", "cos"],
    lon=["sin", "cos"],
    biome=["biome"],
    eco_region=["eco_region"],
    month=["sin_month", "cos_month"],
    esa_worldcover=["map"],
)

# Value that marks "no data" in each modality; loaders replace it with NaN.
# Note that era5 uses +inf while the other float modalities use -inf.
NO_DATA_VAL = dict(
    sentinel2=0,
    sentinel2_cloudmask=65535,
    sentinel2_cloudprod=65535,
    sentinel2_scl=255,
    sentinel1=float("-inf"),
    aster=float("-inf"),
    canopy_height_eth=255,
    dynamic_world=0,
    esa_worldcover=255,
    lat=float("-inf"),
    lon=float("-inf"),
    month=float("-inf"),
    era5=float("inf"),
    biome=255,
    eco_region=65535,
)

# Lower-case aliases (the same dict objects) used by the exploration cell.
modalities_full = MODALITIES_FULL
modalities = MODALITIES

In [10]:
import numpy as np
from collections import OrderedDict

# Load one training example by hand: select the bands of each modality,
# turn no-data values into NaN, normalize with the band statistics and cast
# to float32.
idx = 0
return_dict = OrderedDict()
name = data_full['metadata'][indices[idx]][0].decode("utf-8")
l2a = tile_info[name]["S2_type"] == "l2a"

for modality in modalities.keys():
    print("modality", modality)
    if modalities[modality] == "all":
        modality_idx = list(range(len(modalities_full[modality])))
    else:
        # Position in MODALITIES_FULL = index along the stored band axis.
        modality_idx = [
            modalities_full[modality].index(m)
            for m in modalities[modality]
        ]
    print('modality_idx', modality_idx)

    data = np.array(data_full[modality][indices[idx], modality_idx, ...])
    print('data', data.shape, data[0, 0, :5])

    # NO_DATA_VAL holds fill values in the *raw* value space (e.g. 0 for
    # sentinel2), so mask BEFORE normalizing: after normalization a raw 0
    # becomes -mean/std and would never match. dynamic_world is not masked.
    if modality != "dynamic_world":
        data = np.where(data == NO_DATA_VAL[modality], np.nan, data)

    # band_stats stores the sentinel2 stats under sentinel2_l1c / sentinel2_l2a,
    # chosen from the tile's S2_type.
    if modality == "sentinel2":
        modality_ = "sentinel2_l2a" if l2a else "sentinel2_l1c"
    else:
        modality_ = modality

    # These modalities are not normalized (presumably categorical labels).
    if modality not in ["biome", "eco_region", "dynamic_world", "esa_worldcover"]:
        means = np.array(norm_stats[modality_]["mean"])[modality_idx]
        stds = np.array(norm_stats[modality_]["std"])[modality_idx]
        if modality in ["era5", "lat", "lon", "month"]:
            # Stats applied as-is, without spatial broadcasting
            # (presumably these modalities are stored as vectors).
            data = (data - means) / stds
        else:
            # Per-band stats reshaped to (bands, 1, 1) so they broadcast
            # over the two trailing spatial axes of (bands, H, W) data.
            data = (data - means[:, None, None]) / stds[:, None, None]
    print('data', data.shape, data[0, 0, :5])

    return_dict[modality] = data.astype(np.float32)

    print()

return_dict['sentinel2'].shape

modality sentinel2
modality_idx [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12]
data (12, 64, 64) [1627 1625 1629 1640 1650]
data (12, 64, 64) [-0.15649307 -0.1578088  -0.15517734 -0.14794082 -0.14136217]

modality sentinel1
modality_idx [0, 1, 2, 3, 4, 5, 6, 7]
data (8, 64, 64) [-7.246237  -8.437068  -9.1128235 -9.794962  -9.71821  ]
data (8, 64, 64) [0.86320915 0.6310232  0.49926571 0.36626369 0.38122859]



(12, 64, 64)

# TFDS Builder

In [11]:
# Every band of every modality. A band's position in its list is its index
# along the band axis of the matching HDF5 dataset (the builder below turns
# band names into indices with list.index()).
MODALITIES_FULL = {
    "sentinel2": ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8A", "B8", "B9", "B10", "B11", "B12"],
    "sentinel2_cloudmask": ["QA60"],
    "sentinel2_cloudprod": ["MSK_CLDPRB"],
    "sentinel2_scl": ["SCL"],
    "sentinel1": ["asc_VV", "asc_VH", "asc_HH", "asc_HV", "desc_VV", "desc_VH", "desc_HH", "desc_HV"],
    "aster": ["elevation", "slope"],
    "era5": [
        "prev_month_avg_temp", "prev_month_min_temp", "prev_month_max_temp", "prev_month_total_precip",
        "curr_month_avg_temp", "curr_month_min_temp", "curr_month_max_temp", "curr_month_total_precip",
        "year_avg_temp", "year_min_temp", "year_max_temp", "year_total_precip",
    ],
    "dynamic_world": ["landcover"],
    "canopy_height_eth": ["height", "std"],
    "lat": ["sin", "cos"],
    "lon": ["sin", "cos"],
    "biome": ["biome"],
    "eco_region": ["eco_region"],
    "month": ["sin_month", "cos_month"],
    "esa_worldcover": ["map"],
}

# Bands fed to the builder: every Sentinel-2 band except B10, and all
# Sentinel-1 channels ("all" selects every entry of MODALITIES_FULL).
MODALITIES = {
    "sentinel2": [band for band in MODALITIES_FULL["sentinel2"] if band != "B10"],
    "sentinel1": "all",
}

# Value that marks "no data" per modality; the loader replaces it with NaN.
NO_DATA_VAL = {
    "sentinel2": 0, "sentinel2_cloudmask": 65535, "sentinel2_cloudprod": 65535, "sentinel2_scl": 255,
    "sentinel1": float("-inf"), "aster": float("-inf"),
    "canopy_height_eth": 255, "dynamic_world": 0, "esa_worldcover": 255,
    "lat": float("-inf"), "lon": float("-inf"), "month": float("-inf"),
    "era5": float("inf"),  # +inf, unlike the other float modalities
    "biome": 255, "eco_region": 65535,
}

In [12]:
import tensorflow_datasets as tfds
import numpy as np

class MMEarthBuilder(tfds.core.GeneratorBasedBuilder):
    """TFDS builder that exports MMEarth examples from the HDF5 dump.

    Every example holds the modalities selected via ``modalities`` (no-data
    values set to NaN, normalized with the band statistics, cast to float32)
    plus the tile name under ``'id'``. Only a ``'train'`` split is produced.
    """

    VERSION = tfds.core.Version('0.0.1')

    # Directory holding the MMEarth HDF5 file and its JSON side files.
    DEFAULT_DATA_ROOT = '/home/admin/john/data/mmearth'
    # Number of leading train indices exported by default (None exports all).
    DEFAULT_MAX_EXAMPLES = 10000

    # Modalities that are never normalized (presumably categorical labels).
    _UNNORMALIZED = ("biome", "eco_region", "dynamic_world", "esa_worldcover")
    # Modalities whose mean/std are applied without spatial broadcasting.
    _NO_SPATIAL_AXES = ("era5", "lat", "lon", "month")

    def __init__(self, modalities: dict, *, data_root=None,
                 max_examples=DEFAULT_MAX_EXAMPLES, **kwargs):
        """Creates the builder.

        Args:
          modalities: maps a modality name to a list of band names (looked up
            in MODALITIES_FULL) or to the string "all".
          data_root: directory containing the data files; defaults to
            DEFAULT_DATA_ROOT.
          max_examples: export only the first ``max_examples`` train indices;
            None exports all of them.
          **kwargs: forwarded to ``tfds.core.GeneratorBasedBuilder``.
        """
        super().__init__(**kwargs)
        self.modalities = modalities
        self.data_root = self.DEFAULT_DATA_ROOT if data_root is None else data_root
        self.max_examples = max_examples

    def _info(self):
        """Declares the feature spec.

        NOTE: shapes are fixed to the 12-band sentinel2 / 8-band sentinel1
        selection at 64x64; a different ``modalities`` needs a matching spec.
        """
        return tfds.core.DatasetInfo(
            builder=self,
            features=tfds.features.FeaturesDict({
                'sentinel2': tfds.features.Tensor(shape=(12, 64, 64), dtype=np.dtype("float32")),
                'sentinel1': tfds.features.Tensor(shape=(8, 64, 64), dtype=np.dtype("float32")),
                'id': tfds.features.Text(),
            }),
        )

    def _split_generators(self, dl_manager):
        """Returns the ``'train'`` split generator.

        ``dl_manager`` is unused: all inputs are read from ``self.data_root``.
        """
        import json
        from pathlib import Path

        data_root = Path(self.data_root)

        def load_json(file_name):
            with open(data_root / file_name, "r") as f:
                return json.load(f)

        # Slicing with None keeps every index.
        indices = load_json('data_1M_v001_64_splits.json')["train"][:self.max_examples]
        tile_info = load_json('data_1M_v001_64_tile_info.json')
        norm_stats = load_json('data_1M_v001_64_band_stats.json')

        # Pass the path rather than an open handle: _generate_examples opens
        # the file itself and closes it when generation finishes.
        return {
            'train': self._generate_examples(
                data_root / 'data_1M_v001_64.h5', indices, tile_info, norm_stats)
        }

    def _generate_examples(self, data_path, indices, tile_info, norm_stats):
        """Yields ``(key, example)`` pairs for the given HDF5 row indices.

        Args:
          data_path: path of the MMEarth HDF5 file.
          indices: HDF5 row indices to export.
          tile_info: per-tile metadata; ``S2_type`` selects the sentinel2 stats.
          norm_stats: per-modality dicts with ``mean`` and ``std`` lists.
        """
        import h5py

        with h5py.File(data_path, 'r') as data_full:
            for idx in indices:
                name = data_full['metadata'][idx][0].decode("utf-8")
                l2a = tile_info[name]["S2_type"] == "l2a"

                example = {
                    modality: self._read_modality_array(data_full, idx, modality, l2a, norm_stats)
                    for modality in self.modalities
                }
                example["id"] = name

                yield name, example

    def _read_modality_array(self, data_full, idx, modality, l2a, norm_stats):
        """Reads one modality of one example, masks no-data and normalizes it.

        Returns:
          float32 numpy array holding the selected bands of ``modality``.
        """
        full_bands = MODALITIES_FULL[modality]
        selection = self.modalities[modality]
        if selection == "all":
            band_idx = list(range(len(full_bands)))
        else:
            band_idx = [full_bands.index(band) for band in selection]

        data = np.array(data_full[modality][idx, band_idx, ...])

        # NO_DATA_VAL holds fill values in the raw value space (e.g. 0 for
        # sentinel2), so mask BEFORE normalizing: afterwards a raw 0 becomes
        # -mean/std and would never match. dynamic_world is left unmasked.
        if modality != "dynamic_world":
            data = np.where(data == NO_DATA_VAL[modality], np.nan, data)

        if modality not in self._UNNORMALIZED:
            # Inside band_stats, sentinel2 is stored as sentinel2_l1c or sentinel2_l2a.
            if modality == "sentinel2":
                stats_key = "sentinel2_l2a" if l2a else "sentinel2_l1c"
            else:
                stats_key = modality
            means = np.array(norm_stats[stats_key]["mean"])[band_idx]
            stds = np.array(norm_stats[stats_key]["std"])[band_idx]
            if modality not in self._NO_SPATIAL_AXES:
                # (bands,) -> (bands, 1, 1) so the per-band stats broadcast
                # over the two trailing spatial axes of (bands, H, W) data.
                means = means[:, None, None]
                stds = stds[:, None, None]
            data = (data - means) / stds

        return data.astype(np.float32)

In [13]:
# Generate the TFDS dataset with the builder defined above, then open it.
builder = MMEarthBuilder(modalities=MODALITIES)
mmearth_download_config = tfds.download.DownloadConfig(
    manual_dir='/home/admin/john/data/mmearth',
)
# NOTE(review): MMEarthBuilder._split_generators does not use dl_manager, so
# manual_dir presumably has no effect here — confirm.
builder.download_and_prepare(
    download_dir='/home/admin/john/data/mmearth_',
    download_config=mmearth_download_config,
)
builder.as_dataset()

{'train': <_PrefetchDataset element_spec={'id': TensorSpec(shape=(), dtype=tf.string, name=None), 'sentinel1': TensorSpec(shape=(8, 64, 64), dtype=tf.float32, name=None), 'sentinel2': TensorSpec(shape=(12, 64, 64), dtype=tf.float32, name=None)}>}

In [1]:
from mmearth.mmearth_dataset import MMEarthBuilder
import tensorflow_datasets as tfds
from scenic.dataset_lib import dataset_utils
import functools
from scenic.dataset_lib.big_transfer import builder
from configs.loca_mmearth64_tiny16 import get_config
import ops
import tensorflow as tf

# Band selection used by the print cell further down; note that
# MMEarthBuilder() below is constructed without it.
MODALITIES = {
    "sentinel2": "B1 B2 B3 B4 B5 B6 B7 B8A B8 B9 B11 B12".split(),
    "sentinel1": "all",
}

config = get_config()

# Prepare (or reuse) the TFDS data, then read the train split.
mmearth_builder = MMEarthBuilder()
mmearth_builder.download_and_prepare()
read_config = tfds.ReadConfig(
    skip_prefetch=True,  # We prefetch after pipeline.
    try_autocache=False,  # We control this, esp. for few-shot.
    add_tfds_id=True,
)
dataset = mmearth_builder.as_dataset(
    split="train",
    shuffle_files=True,
    read_config=read_config,
)

2025-02-25 16:04:08.259675: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-25 16:04:08.259759: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-25 16:04:08.259786: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://gi

[1mDownloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /home/admin/tensorflow_datasets/mm_earth_builder/0.0.3...[0m


  from .autonotebook import tqdm as notebook_tqdm
Generating splits...:   0%|          | 0/1 [00:00<?, ? splits/s]
[Aerating train examples...: 0 examples [00:00, ? examples/s]
[Aerating train examples...: 69 examples [00:01, 68.39 examples/s]
[Aerating train examples...: 141 examples [00:02, 70.36 examples/s]
[Aerating train examples...: 212 examples [00:03, 69.84 examples/s]
[Aerating train examples...: 282 examples [00:04, 69.49 examples/s]
[Aerating train examples...: 352 examples [00:05, 68.68 examples/s]
[Aerating train examples...: 421 examples [00:06, 68.66 examples/s]
[Aerating train examples...: 490 examples [00:07, 68.52 examples/s]
[Aerating train examples...: 559 examples [00:08, 68.40 examples/s]
[Aerating train examples...: 628 examples [00:09, 68.09 examples/s]
[Aerating train examples...: 697 examples [00:10, 67.62 examples/s]
[Aerating train examples...: 765 examples [00:11, 67.43 examples/s]
[Aerating train examples...: 833 examples [00:12, 67.27 example

[1mDataset mm_earth_builder downloaded and prepared to /home/admin/tensorflow_datasets/mm_earth_builder/0.0.3. Subsequent calls will reuse this data.[0m
Instructions for updating:
Use `tf.data.Dataset.counter(...)` instead.


Instructions for updating:
Use `tf.data.Dataset.counter(...)` instead.


In [None]:
# Preprocess, shuffle and batch the train split for a single device.
train_preprocess_fn = builder.get_preprocess_fn(config.dataset_configs.pp_train)
dataset = dataset_utils.make_pipeline(
    data=dataset,
    preprocess_fn=train_preprocess_fn,
    batch_size=1,
    drop_remainder=True,
    cache=False,
    repeats=None,
    prefetch=config.dataset_configs.get('prefetch_to_host', 2),
    shuffle_buffer_size=config.dataset_configs.shuffle_buffer_size,
    repeat_after_batching=False,
    # NOTE(review): presumably skips examples whose preprocessing fails,
    # which can hide pipeline bugs while debugging — confirm.
    ignore_errors=True,
)

n_train_ex = dataset_utils.get_num_examples("mm_earth_builder", "train")

# Lazily apply tf_to_numpy and then shard_batches to every batch.
shard_batches = functools.partial(dataset_utils.shard, n_devices=1)
train_iter = map(shard_batches, map(dataset_utils.tf_to_numpy, iter(dataset)))
next(train_iter)

In [6]:
# Show the imported builder class and an instance built with MODALITIES.
print(f"{MMEarthBuilder}")
print(f"{MMEarthBuilder(MODALITIES)}")

<class 'mmearth.mmearth_dataset.MMEarthBuilder'>
<mmearth.mmearth_dataset.MMEarthBuilder object at 0x7ff4742409d0>


In [10]:
from imagenet import Imagenet

# Prepare the ImageNet TFDS dataset; its input files are expected under manual_dir.
builder = Imagenet()
imagenet_download_config = tfds.download.DownloadConfig(
    manual_dir='/home/admin/john/data/ImageNet',
)
builder.download_and_prepare(
    download_dir='/home/admin/john/data/imagenet2012',
    download_config=imagenet_download_config,
)

[1mDownloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /home/admin/tensorflow_datasets/imagenet/5.1.0...[0m


Generating splits...:   0%|          | 0/1 [00:00<?, ? splits/s]
[Aerating validation examples...: 0 examples [00:00, ? examples/s]
[Aerating validation examples...: 4514 examples [00:01, 4513.37 examples/s]
[Aerating validation examples...: 9028 examples [00:02, 3106.51 examples/s]
[Aerating validation examples...: 12387 examples [00:03, 3040.16 examples/s]
[Aerating validation examples...: 15555 examples [00:04, 3034.70 examples/s]
[Aerating validation examples...: 18666 examples [00:06, 2782.89 examples/s]
[Aerating validation examples...: 21511 examples [00:07, 2753.88 examples/s]
[Aerating validation examples...: 24303 examples [00:08, 2609.02 examples/s]
[Aerating validation examples...: 26939 examples [00:09, 2555.05 examples/s]
[Aerating validation examples...: 29509 examples [00:11, 1932.23 examples/s]
[Aerating validation examples...: 31966 examples [00:12, 2054.51 examples/s]
[Aerating validation examples...: 34820 examples [00:13, 2254.38 examples/s]
[Aerating 

[1mDataset imagenet downloaded and prepared to /home/admin/tensorflow_datasets/imagenet/5.1.0. Subsequent calls will reuse this data.[0m




In [3]:
import loca_dataset
from scenic.train_lib import train_utils
from configs.loca_mmearth64_tiny16 import get_config
import jax
import ops

config = get_config()

# Fixed seed: the first split key seeds the data pipeline, the second is kept as `rng`.
data_rng, rng = jax.random.split(jax.random.key(77))

dataset = train_utils.get_dataset(config, data_rng)

aldkflskf
aldkflskf
train_ds <_PrefetchDataset element_spec={'reference': TensorSpec(shape=(4, 96, 96, 12), dtype=tf.float32, name=None), 'query0': TensorSpec(shape=(4, 96, 96, 12), dtype=tf.float32, name=None), 'query0_mask': TensorSpec(shape=(4, 12, 12, 1), dtype=tf.int32, name=None), 'query0_box': TensorSpec(shape=(4, 5), dtype=tf.float32, name=None)}>


In [None]:
# Fetch one element from the training iterator to inspect what the pipeline yields.
next(dataset.train_iter)

In [1]:
import loca_dataset
from scenic.train_lib import train_utils
import jax
import ops
from configs.loca_imnet1k_base16 import get_config

config = get_config()  # from configs.loca_imnet1k_base16

# Same fixed seed as the MMEarth run; the first key drives the data pipeline.
seed_key = jax.random.key(77)
data_rng, rng = jax.random.split(seed_key)
dataset = train_utils.get_dataset(config, data_rng)


2025-02-24 14:31:41.138524: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-02-24 14:31:41.138597: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-24 14:31:41.140685: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://gi

Instructions for updating:
Use `tf.data.Dataset.counter(...)` instead.


Instructions for updating:
Use `tf.data.Dataset.counter(...)` instead.


aldkflskf
aldkflskf
Instructions for updating:
Use `tf.data.Dataset.ignore_errors` instead.


Instructions for updating:
Use `tf.data.Dataset.ignore_errors` instead.


train_ds <_PrefetchDataset element_spec={'reference': TensorSpec(shape=(4, 224, 224, 3), dtype=tf.float32, name=None), 'query0': TensorSpec(shape=(4, 224, 224, 3), dtype=tf.float32, name=None), 'query0_mask': TensorSpec(shape=(4, 14, 14, 1), dtype=tf.int32, name=None), 'query0_box': TensorSpec(shape=(4, 5), dtype=tf.float32, name=None)}>


In [2]:
# Fetch one element from the training iterator to inspect what the pipeline yields.
next(dataset.train_iter)

{'query0': Array([[[[[ 1.718119  ,  1.9067225 ,  2.0804021 ],
           [ 1.718119  ,  1.9067225 ,  2.0804021 ],
           [ 1.718119  ,  1.9067225 ,  2.0804021 ],
           ...,
           [-0.33690262, -0.07591756,  0.10612051],
           [-0.39402273, -0.13431281,  0.04798479],
           [-0.43750057, -0.17876111,  0.00373416]],
 
          [[ 1.718119  ,  1.9067225 ,  2.0804021 ],
           [ 1.718119  ,  1.9067225 ,  2.0804021 ],
           [ 1.718119  ,  1.9067225 ,  2.0804021 ],
           ...,
           [-0.37087572, -0.1242025 ,  0.06526444],
           [-0.39041173, -0.14417459,  0.04538113],
           [-0.42343867, -0.17793888,  0.01176702]],
 
          [[ 1.718119  ,  1.9067225 ,  2.0804021 ],
           [ 1.718119  ,  1.9067225 ,  2.0804021 ],
           [ 1.718119  ,  1.9067225 ,  2.0804021 ],
           ...,
           [-0.40259242, -0.18236293,  0.02268434],
           [-0.41035587, -0.19125786,  0.01306706],
           [-0.42678162, -0.19161282,  0.0036131 ]],