In [1]:
import sys
import warnings
import os
# Silence library warnings so the notebook output stays readable, and make
# the local `model/` directory importable for the project modules below.
warnings.simplefilter("ignore")
sys.path.append(os.path.join(os.getcwd(), "model/"))

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from einops import rearrange

from model.src.datamodule import ClayDataModule
from model.src.model import ClayMAEModule

In [19]:
# Runtime device: prefer the GPU when one is visible to torch.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Locations of the data, the pretrained Clay v1 checkpoint and its metadata.
DATA_DIR = "/home/ubuntu/data"
CHECKPOINT_PATH = "model/checkpoints/v1/clay-v1-base.ckpt"
METADATA_PATH = "model/configs/metadata.yaml"

# Input chip edge length in pixels.
CHIP_SIZE = 224

In [17]:
# import timm
# print(timm.list_models())


In [20]:
# The embeddings are visualized later, so the input image is neither masked
# nor are its patches shuffled: mask_ratio=0.0 and shuffle=False.
module = ClayMAEModule.load_from_checkpoint(
    checkpoint_path=CHECKPOINT_PATH,
    metadata_path=METADATA_PATH,
    shuffle=False,
    mask_ratio=0.0,
)

# eval() returns the module itself; assigning it keeps the cell from echoing it.
module = module.eval()

model.safetensors:   0%|          | 0.00/343M [00:00<?, ?B/s]

# Preparing Earth Observation Data with `stacchip`

This section demonstrates how to use the `stacchip` library to prepare Earth observation data from a STAC catalog. It samples image chips from the specified collection, groups them into batches, and saves each batch as an `.npz` file for the `ClayDataModule` to load.

### Overview:
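1. **STAC Catalog Query**: Fetch items from a collection such as NAIP, Landsat, or Sentinel-2. Note that the asset names passed to `Chipper` differ between collections (`"image"` is used here for NAIP).
2. **Chipping and Indexing**: Generate image chips on the fly using `NoStatsChipIndexer` and `Chipper`.
3. **Batching and Saving**: Save the chips as `.npz` files, together with all-zero placeholder values for the normalized latitude, longitude, week, and hour metadata.
4. **Integration**: Write every batch into one output directory, which is then passed to the `ClayDataModule` as `data_dir`.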

Replace parameters like `COLLECTION_NAME`, `OUTPUT_DIR`, and `BATCH_SIZE` as needed for your dataset.

In [None]:
import random
import os
import numpy as np
from pathlib import Path
import pystac_client
from stacchip.indexer import NoStatsChipIndexer
from stacchip.chipper import Chipper

# Optimize GDAL settings for cloud-optimized reading
os.environ["GDAL_DISABLE_READDIR_ON_OPEN"] = "EMPTY_DIR"
os.environ["AWS_REQUEST_PAYER"] = "requester"

# Parameters
# NOTE(review): confirm the directory layout ClayDataModule expects (e.g. whether
# it reads per-platform subdirectories) against model/src/datamodule.py.
OUTPUT_DIR = "/path/to/output/directory"
STAC_CATALOG_URL = "https://earth-search.aws.element84.com/v1"
COLLECTION_NAME = "naip"  # Replace with appropriate collection name
MAX_ITEMS = 100  # Number of items to fetch from the STAC catalog
BATCH_SIZE = 128  # Chips per saved .npz file
CHIPS_PER_ITEM = 5  # Chips sampled per item (fewer if the item has fewer chips)
STAC_CHIP_SIZE = 256  # Chip edge length in pixels, passed to the indexer
RANDOM_SEED = 42  # Makes item shuffling and chip sampling reproducible
# Per-chip metadata arrays saved alongside the pixels in every .npz file.
METADATA_KEYS = ("lat_norm", "lon_norm", "week_norm", "hour_norm")


def _save_batch(output_dir, batch_index, pixels, metadata):
    """Stack one batch of chips and write it to ``cube_{batch_index}.npz``.

    Args:
        output_dir: Directory the file is written into.
        batch_index: Sequential batch number used in the file name, so every
            batch gets its own file.
        pixels: List of per-chip arrays, stacked along a new leading axis.
        metadata: Mapping of metadata key -> list of per-chip (1, 2) arrays;
            each list is vstacked into a (len(pixels), 2) array.
    """
    np.savez(
        Path(output_dir) / f"cube_{batch_index}.npz",
        pixels=np.stack(pixels),
        **{key: np.vstack(values) for key, values in metadata.items()},
    )


def _empty_metadata():
    """Return a fresh mapping of metadata key -> empty per-chip list."""
    return {key: [] for key in METADATA_KEYS}


# Dedicated RNG so seeding here does not touch the global `random` state.
rng = random.Random(RANDOM_SEED)

# Create the output directory
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# Query the STAC catalog and visit the items in a random (seeded) order
catalog = pystac_client.Client.open(STAC_CATALOG_URL)
items = catalog.search(collections=[COLLECTION_NAME], max_items=MAX_ITEMS).item_collection()
items_list = list(items)
rng.shuffle(items_list)

batch_pixels = []
batch_metadata = _empty_metadata()
batch_index = 0  # number of .npz files written so far

for item in items_list:
    print(f"Processing item: {item.id}")

    # Index the chips in the item and chip the "image" asset
    indexer = NoStatsChipIndexer(item, chip_size=STAC_CHIP_SIZE)
    chipper = Chipper(indexer, assets=["image"])

    # sample() raises ValueError when asked for more chips than the item has
    n_chips = min(CHIPS_PER_ITEM, len(chipper))
    for chip_id in rng.sample(range(len(chipper)), n_chips):
        _, _, chip = chipper[chip_id]
        batch_pixels.append(chip["image"])  # presumably (C, H, W) per chip — TODO confirm

        # All-zero placeholder normalized metadata (replace with real values if available)
        for key in METADATA_KEYS:
            batch_metadata[key].append(np.zeros((1, 2), dtype=np.float32))

        # Save as soon as a batch is full
        if len(batch_pixels) == BATCH_SIZE:
            _save_batch(OUTPUT_DIR, batch_index, batch_pixels, batch_metadata)
            batch_index += 1
            batch_pixels = []
            batch_metadata = _empty_metadata()

# Flush the final, partially filled batch so no sampled chips are lost
if batch_pixels:
    _save_batch(OUTPUT_DIR, batch_index, batch_pixels, batch_metadata)
    batch_index += 1

print("Data preparation complete!")


In [None]:
# Use the same import path as the top of the notebook: the first cell adds only
# `model/` (not `model/src/`) to sys.path, so a bare `datamodule` import is
# unlikely to resolve.
from model.src.datamodule import ClayDataModule

DATA_DIR = OUTPUT_DIR  # read the .npz batches written by the data-preparation cell
METADATA_PATH = "model/configs/metadata.yaml"
# `size` passed to the datamodule; this may differ from the size of the prepared
# chips — NOTE(review): confirm how the datamodule crops/resizes to it.
CHIP_SIZE = 224

dm = ClayDataModule(
    data_dir=DATA_DIR,
    metadata_path=METADATA_PATH,
    size=CHIP_SIZE,
    batch_size=1,
    num_workers=1,
)
dm.setup(stage="fit")