# Dataloader Example


This script demonstrates how to efficiently load and iterate over a satellite Earth observation dataset stored in the Zarr format, using a modular PyTorch DataLoader setup. 

In [2]:
import zarr
from zarr.storage import DirectoryStore
import fsspec

zarr_path = "/Data/phisatnet_clouds/phisatnet_clouds.zarr/zarr"

# 1. Point at the on‑disk zarr directory
store = DirectoryStore(zarr_path)
root = zarr.open_group(store=store, mode="r")

# Print the group tree to explore its structure
print(root.tree())


GroupNotFoundError: group not found at path ''

In [None]:
import zarr
from zarr.storage import DirectoryStore

store = DirectoryStore("/Data/phisatnet_clouds")
root = zarr.open(store=store, mode="r", zarr_format=3)
dataset_group = root["trainval"]

sample_id = sorted(dataset_group.keys())[0]
sample = dataset_group[sample_id]
img = sample["img"][:]
label = sample["label"][:]
print(f"image shape: {img.shape}, label shape: {label.shape}")
# Print all metadata attributes
print(f"Attributes for sample '{sample_id}':")
for key, value in sample_group.attrs.items():
    print(f"  {key}: {value}")

PathNotFoundError: nothing found at path ''

In [3]:
import zarr
from zarr.storage import DirectoryStore

zarr_path = "/Data/worldfloods/worldfloods.zarr"

root = zarr.open(zarr_path, mode='r')

# Choose the split ("trainval" or "test")
dataset_set = "trainval"
dataset_group = root[dataset_set]

# Access the sample group
sample_ids = sorted(dataset_group.keys())

sample_id = sample_ids[0]
sample_group = dataset_group[sample_id]

# Load image
img = sample_group['img'][:]
label = sample_group['label'][:]
print(f"image shape: {img.shape}, label shape: {label.shape}")
# Print all metadata attributes
print(f"Attributes for sample '{sample_id}':")
for key, value in sample_group.attrs.items():
    print(f"  {key}: {value}")

image shape: (8, 512, 512), label shape: (1, 512, 512)
Attributes for sample '0000000':
  cloud_cover: nan
  crs: EPSG:32629
  datatake: 21-10-2017 11:54:57
  geolocation: {'LL': [nan, nan], 'LR': [nan, nan], 'UL': [nan, nan], 'UR': [nan, nan]}
  sensor: S2A
  sensor_orbit: ASCENDING
  sensor_orbit_number: 0
  sensor_resolution: 4.75
  spectral_bands_ordered: B02-B03-B04-B08-B05-B06-B07-PAN
  sun_azimuth: nan
  sun_elevation: nan
  task: segmentation
  view_azimuth: nan
  view_elevation: nan


In [1]:
import zarr
zarr_path = "/Data/lpl_burned_area/burned.zarr" #"/Data/fire_dataset/fire_dataset.zarr"
root = zarr.open(zarr_path, mode='r')

# Choose the split ("trainval" or "test")
dataset_set = "trainval"
dataset_group = root[dataset_set]

# Pick a sample ID (e.g., first one)
sample_ids = sorted(dataset_group.keys())
sample_id = sample_ids[0]  # or any other valid index

# Access the sample group
sample_group = dataset_group[sample_id]

# Print all metadata attributes
print(f"Attributes for sample '{sample_id}':")
for key, value in sample_group.attrs.items():
    print(f"  {key}: {value}")

Attributes for sample '0000000':
  cloud_cover: nan
  datatake: 00-00-0000 00:00:00
  geolocation: {'LL': [nan, nan], 'LR': [nan, nan], 'UL': [nan, nan], 'UR': [nan, nan]}
  sensor: S2A
  sensor_orbit: ASCENDING
  sensor_orbit_number: 0
  sensor_resolution: 10
  spectral_bands_ordered: B2-B3-B4-B4
  sun_azimuth: nan
  sun_elevation: nan
  task: segmentation
  view_azimuth: nan
  view_elevation: nan


In [4]:
from typing import List, Dict
def collate_fn(batch: List[Dict]) -> Dict:
    """
    Custom collate function to handle different task types and metadata.
    
    Args:
        batch: List of sample dictionaries from PhiSatDataset
        
    Returns:
        Dictionary with batched tensors and metadata
    """
    # Group samples by task to handle different label shapes
    task_groups = {}
    for sample in batch:
        task = sample['task']
        if task not in task_groups:
            task_groups[task] = []
        task_groups[task].append(sample)
    
    result = {}
    
    # Process each task group separately
    for task, samples in task_groups.items():
        # Get all keys from the first sample
        keys = samples[0].keys()
        
        for key in keys:
            # Skip task and sample_id for batching
            if key in ['task', 'sample_id']:
                continue
                
            # Handle tensors
            if isinstance(samples[0][key], torch.Tensor):
                # Stack tensors with same shapes
                try:
                    result[f"{task}_{key}"] = torch.stack([s[key] for s in samples])
                except RuntimeError:
                    # If tensors have different shapes, return as list
                    result[f"{task}_{key}"] = [s[key] for s in samples]
            else:
                # For non-tensor data, collect as list
                result[f"{task}_{key}"] = [s[key] for s in samples]
        
        # Store sample IDs
        result[f"{task}_sample_ids"] = [s['sample_id'] for s in samples]
        
    # Store task information
    result['tasks'] = list(task_groups.keys())
    result['task_counts'] = {task: len(samples) for task, samples in task_groups.items()}
    
    return result

In [1]:
from data_loader import get_zarr_dataloader, NormalizeChannels
from tqdm import tqdm
import torch
import numpy as np 

# Path to the input Zarr dataset
zarr_path = "/Data/lpl_burned_area/burned.zarr"
# Select dataset split: "trainval" or "test"
dataset_set = "trainval"

# Step 1: Compute dataset-wide per-band mean and std
print("Computing mean and std across dataset...")
_, _, dataloader = get_zarr_dataloader(
    zarr_path=zarr_path,
    dataset_set=dataset_set,
    batch_size=16,
    shuffle=False,
    num_workers=4,
    task_filter="segmentation",
    metadata_keys=["sensor", "timestamp", "geolocation", "crs"],
    num_classes=4
)

sum_ = 0
sum_sq = 0
total_pixels = 0

for batch in tqdm(dataloader, desc="Computing stats"):
    for task in batch['tasks']:
        images = np.array(batch[f'{task}_img'])  # shape: (B, H, W, C)
        if images.ndim != 4:
            raise ValueError("Expected image tensor of shape (B, H, W, C)")
        batch_size, height, width, _ = images.shape
        pixels_in_batch = batch_size * height * width

        sum_ += images.sum(axis=(0, 1, 2))
        sum_sq += (images ** 2).sum(axis=(0, 1, 2))
        total_pixels += pixels_in_batch

mean = sum_ / total_pixels
std = np.sqrt((sum_sq / total_pixels) - (mean ** 2))

print("Mean per band:", mean.tolist())
print("Stddev per band:", std.tolist())

Computing mean and std across dataset...
Dataset trainval shapes: img=(256, 256, 7), label=(256, 256, 4)


Computing stats: 100%|████████████████████████████████████████████████████████████████████████| 487/487 [01:06<00:00,  7.30it/s]

Mean per band: [0.5692603492540789, 0.5233146455770651, 0.49774728208504626, 0.5614061973077787, 0.5094977101466148, 0.5503450336828751, 0.5719299002762076]
Stddev per band: [0.24279108867296925, 0.25451220952717407, 0.277410560398893, 0.28924007207410934, 0.2766535835665443, 0.2841112679453489, 0.28949325669342035]





In [4]:
from data_loader2 import get_zarr_dataloader, NormalizeChannels
from tqdm import tqdm
import torch
import numpy as np 

# Path to the input Zarr dataset
zarr_path = "/Data/worldfloods/worldfloods.zarr"
# Select dataset split: "trainval" or "test"
dataset_set = "trainval"

# Step 1: Compute dataset-wide per-band mean and std
print("Computing mean and std across dataset...")
#_, _, dataloader = get_zarr_dataloader(
#    zarr_path=zarr_path,
#    dataset_set=dataset_set,
#    batch_size=16,
#    shuffle=False,
#    num_workers=4,
#    task_filter="segmentation",
#    metadata_keys=["sensor", "timestamp", "geolocation", "crs"],
#    num_classes=4
#)

#sum_ = 0
#sum_sq = 0
#total_pixels = 0

#for batch in tqdm(dataloader, desc="Computing stats"):
#    for task in batch['tasks']:
#        try:
#            images = np.array(batch[f'{task}_img'])  # shape: (B, H, W, C)
#            if images.ndim != 4:
#                raise ValueError("Expected image tensor of shape (B, H, W, C)")
#            batch_size, height, width, _ = images.shape
#            pixels_in_batch = batch_size * height * width
#    
#            sum_ += images.sum(axis=(0, 1, 2))
#            sum_sq += (images ** 2).sum(axis=(0, 1, 2))
#            total_pixels += pixels_in_batch
#        except Exception as e:
#            print(e)
#            continue

mean = sum_ / total_pixels
std = np.sqrt((sum_sq / total_pixels) - (mean ** 2))

print("Mean per band:", mean.tolist())
print("Stddev per band:", std.tolist())

Computing mean and std across dataset...
Mean per band: [0.16345226170528732, 0.1485720135879921, 0.14293998321824464, 0.16128031952130834, 0.24135972282919374, 0.1575786149640857, 0.2155737770703539, 0.2502169078127071]
Stddev per band: [0.28859553480558336, 0.3297414341820656, 0.3394307197049047, 0.17330132528004097, 0.2588063781866027, 0.40266684228798705, 0.40724442981311126, 0.3433434216321024]


In [None]:
from data_loader2 import get_zarr_dataloader, NormalizeChannels
from tqdm import tqdm
import torch
import zarr

# Path to the input Zarr dataset
#zarr_path = "/Data/fire_dataset/fire_dataset.zarr"
zarr_path = "/Data/worldfloods/worldfloods.zarr"
# Select dataset split: "trainval" or "test"
dataset_set = "trainval"
#zarr.open(zarr_path)
# Initialize a PyTorch DataLoader from a Zarr-based dataset
_, _, dataloader = get_zarr_dataloader(
    zarr_path=zarr_path,                     # Path to the Zarr archive
    dataset_set=dataset_set,                 # Dataset subset to use
    batch_size=16,                           # Number of samples per batch
    shuffle=True,                            # Enable shuffling (useful for training)
    num_workers=4,                           # Number of parallel workers for loading
    #transform=NormalizeChannels(min_max=True),  # Normalize input channels to [0, 1]
    task_filter="segmentation",              # Only load data for the "segmentation" task
    metadata_keys=["sensor", "timestamp", "geolocation", "crs"],   # Include auxiliary metadata fields
)


all_unique_labels = set()

try:
    for idx, batch in enumerate(tqdm(dataloader, desc="Processing Batches")):
        for task in batch['tasks']:
            labels = batch[f'{task}_label']  # Might be shape (B, H, W) or list of scalars
    
            # Case 1: If labels is a tensor (e.g. B x H x W)
            if isinstance(labels, torch.Tensor):
                unique_vals = torch.unique(labels)
                all_unique_labels.update(unique_vals.cpu().numpy().tolist())
    
            # Case 2: If labels is a list/array of scalars
            elif isinstance(labels, (list, tuple)):
                for label in labels:
                    if isinstance(label, torch.Tensor):
                        unique_vals = torch.unique(label)
                    else:
                        # If label is scalar (e.g. float32), wrap in tensor first
                        label_tensor = torch.tensor(label)
                        unique_vals = torch.unique(label_tensor)
    
                    all_unique_labels.update(unique_vals.cpu().numpy().tolist())
    
            else:
                raise TypeError(f"Unexpected label type: {type(labels)}")
except Exception as e:
    print(e)

# Final result
print(f"\nAll unique label values seen across all batches and tasks: {sorted(all_unique_labels)}")
print(f"Total number of classes: {len(all_unique_labels)}")

Dataset trainval shapes: img=(512, 512, 8), label=(512, 512, 1)
weights: tensor([2.2159, 0.6457]), pos_weights:tensor([4.0633, 0.4754])


Processing Batches:   6%|█                 | 228/3773 [04:35<1:31:33,  1.55s/it]