In [1]:
# Common imports and settings
import os, sys, re
from pathlib import Path
from IPython.display import Markdown
import pandas as pd
pd.set_option("display.max_rows", None)
import xarray as xr
import dask
from dask.distributed import Client
from dask_gateway import Gateway
from dateutil.parser import parse
from dateutil.relativedelta import relativedelta

# Datacube
import datacube
from datacube.utils.aws import configure_s3_access
import odc.geo.xr                                  # https://github.com/opendatacube/odc-geo
from datacube.utils import masking  # https://github.com/opendatacube/datacube-core/blob/develop/datacube/utils/masking.py
from odc.algo import enum_to_bool                  # https://github.com/opendatacube/odc-tools/blob/develop/libs/algo/odc/algo/_masking.py
from dea_tools.plotting import display_map, rgb    # https://github.com/GeoscienceAustralia/dea-notebooks/tree/develop/Tools

import boto3

# Basic plots
%matplotlib inline
# import matplotlib.pyplot as plt
# plt.rcParams['figure.figsize'] = [12, 8]

# Holoviews
# https://holoviz.org/tutorial/Composing_Plots.html
# https://holoviews.org/user_guide/Composing_Elements.html
import hvplot.pandas
import hvplot.xarray
import panel as pn
import colorcet as cc
import cartopy.crs as ccrs
from datashader import reductions
from holoviews import opts
# hv.extension('bokeh', logo=False)
print("Libraries loaded successfully.")

Libraries loaded successfully.


In [2]:
# EASI defaults
# These are convenience functions so that the notebooks in this repository work in all EASI deployments

# The `git.Repo()` part returns the local directory that easi-notebooks has been cloned into
# If using the `easi-tools` functions from another path, replace `repo` with your local path to `easi-notebooks` directory
try:
    import git
    repo = git.Repo('.', search_parent_directories=True).working_tree_dir    # Path to this cloned local directory
except (ImportError, git.InvalidGitRepositoryError):
    repo = Path.home() / 'easi-notebooks'    # Reasonable default
    if not repo.is_dir():
        raise RuntimeError('To use `easi-tools` please provide the local path to `https://github.com/csiro-easi/easi-notebooks`')
if repo not in sys.path:
    sys.path.append(str(repo))    # Add the local path to `easi-notebooks` to python

from easi_tools import EasiDefaults
from easi_tools import initialize_dask, xarray_object_size, mostcommon_crs, heading
#from easi_tools.load_s2l2a import load_s2l2a_with_offset
print("EASI librariies loaded successfully.")

EASI librariies loaded successfully.


In [3]:
gateway = Gateway()
options = gateway.cluster_options()

clusters = gateway.list_clusters()
if not clusters:
    print('Creating new cluster...')
    cluster = gateway.new_cluster(cluster_options=options)
else:
    print(f'Connecting to existing cluster: {clusters[0].name}')
    cluster = gateway.connect(clusters[0].name)


cluster.scale(4)
client = cluster.get_client()

print("Cluster Dashboard:", client.dashboard_link)

Creating new cluster...
Cluster Dashboard: https://hub.csiro.easi-eo.solutions/services/dask-gateway/clusters/easihub.4b63a2362434486eab8eefd9ada71cd8/status


In [8]:
easi = EasiDefaults()
dc = datacube.Datacube()
# Get the centroid of the coordinates of the default extents
central_lat = sum(easi.latitude)/2
central_lon = sum(easi.longitude)/2
# central_lat = -42.019
# central_lon = 146.615

# Set the buffer to load around the central coordinates
# This is a radial distance for the bbox to actual area so bbox 2x buffer in both dimensions
buffer = 0.8

# Compute the bounding box for the study area
study_area_lat = (central_lat - buffer, central_lat + buffer)
study_area_lon = (central_lon - buffer, central_lon + buffer)


# Set the date range to load data over
set_time = easi.time
set_time = (set_time[0], parse(set_time[0]) + relativedelta(years=1))
# set_time = ("2021-01-01", "2021-12-31")

# ------------------------ Measurment choice -----------------
# Selected measurement names (used in this notebook). None` will load all of them
alias = easi.aliases('landsat')

target_keys = ['blue', 'green', 'red', 'nir', 'swir1', 'swir2']
qa_band = alias['qa_band']
qa_mask = easi.qa_mask('landsat')

# Set the resampling method for the bands
resampling_data = {qa_band: "nearest", "*": "average"}
resampling_lables = {"*": "nearest"}

# Set the coordinate reference system and output resolution
set_crs = easi.crs('landsat')  # If defined, else None
set_resolution = easi.resolution('landsat')  # If defined, else None

# Set the scene group_by method
group_by = "solar_day"

Successfully found configuration for deployment "csiro"


In [10]:
# --- 1. Define Products ---
# Input (Features)
products = easi.product('landsat')
spectral_measurements = [alias[k] for k in target_keys] # e.g., 'red', 'green', 'nir'
qa_band = alias['qa_band']
measurements = spectral_measurements + [qa_band]

# Labels (Targets)
label_product = 'ga_ls_landcover_class_cyear_3'
label_measurement = 'level3' # We'll use the base classification level

# --- 2. Load Input Data (Includes QA band) ---
dataset = dc.load(
    product=products,
    x=study_area_lon,
    y=study_area_lat,
    time=set_time,
    measurements=measurements,
    resampling={qa_band: "nearest", "*": "average"}, # Resampling for QA is always 'nearest'
    output_crs=set_crs,
    resolution=set_resolution,
    dask_chunks = {"time":1, "x":2048, "y":2048},
    group_by=group_by,
)

# --- 3. Load Label Data (Aligned to Input) ---
labels_ds = dc.load(
    product=label_product,
    x=study_area_lon,
    y=study_area_lat,
    time=set_time,
    measurements=[label_measurement],
    resampling={"*": "nearest"}, # Labels must use nearest neighbor resampling
    output_crs=set_crs,
    resolution=set_resolution,
    dask_chunks = {"time":1, "x":2048, "y":2048},
    group_by=group_by,
)

print(f"Input dataset size (GiB) {dataset.nbytes / 2**30:.2f}")
print(f"Labels dataset size (GiB) {labels_ds.nbytes / 2**30:.2f}")

Input dataset size (GiB) 18.34
Labels dataset size (GiB) 0.07


In [None]:
# --- 1. Define Products ---
# Input (Features)
products = easi.product('landsat')
spectral_measurements = [alias[k] for k in target_keys] # e.g., 'red', 'green', 'nir'
qa_band = alias['qa_band']
measurements = spectral_measurements + [qa_band]

# Labels (Targets)
label_product = 'ga_ls_landcover_class_cyear_3'
label_measurement = 'level3' # We'll use the base classification level

# --- 2. Load Input Data (Includes QA band) ---
dataset = dc.load(
    product=products,
    x=study_area_lon,
    y=study_area_lat,
    time=set_time,
    measurements=measurements,
    resampling=resampling_data,
    output_crs=set_crs,
    resolution=set_resolution,
    dask_chunks = {"time":1, "x":2048, "y":2048},
    group_by=group_by,
)

aligned_times = dataset.time
print(f"Input data successfully loaded. Total scenes found: {len(aligned_times)}")

# --- 3. Load Label Data (Aligned to Input) ---
labels_ds = dc.load(
    product=label_product,
    x=study_area_lon,
    y=study_area_lat,
    time=set_time,
    measurements=[label_measurement],
    resampling={"*": "nearest"}, # Labels must use nearest neighbor resampling
    output_crs=set_crs,
    resolution=set_resolution,
    dask_chunks = {"time":1, "x":2048, "y":2048},
    group_by=group_by,
)
labels_ds = labels_ds.reindex_like(
    dataset, 
    method='nearest' # <-- This is the key!
)
# NOTE: play around with chunk sizze if training doesn't work
# ds_final = ds_final.chunk({'time': 1, 'x': 512, 'y': 512}) # GPU-friendly chunks
# labels_final = labels_final.chunk({'time': 1, 'x': 512, 'y': 512})
aligned_times = labels_ds.time
print(f"Input data successfully loaded. Total scenes found: {len(aligned_times)}")
print(f"Label dataset size (GiB) {dataset.nbytes / 2**30:.2f}")

print(f"Labels dataset size (GiB) {labels_ds.nbytes / 2**30:.2f}")
display(dataset)
display(labels_ds)

In [37]:
# --- Step A: Create Cloud Mask ---
# The mask is generated from the QA band, but it's a Dask array itself.
qa_mask = easi.qa_mask('landsat')
cloud_free_mask = masking.make_mask(dataset[qa_band], **qa_mask)

# --- Step B: Process Input Features (X) ---
# 1. Select only the spectral bands
training_ds = dataset[spectral_measurements]

# 2. Apply Mask: Replace clouds/shadows/etc. with 0.0 (safe 'black')
ds_masked = training_ds.where(cloud_free_mask, 0.0)

# 3. Normalize: Landsat Int16 (0-10000) -> Float32 (0-1)
ds_norm = ds_masked.astype('float32') / 10000.0

# 4. Stack Bands: Convert from xarray.Dataset to xarray.DataArray (C, Y, X)
ds_final = ds_norm.to_array(dim='band')


# --- Step C: Process Target Labels (Y) ---
labels_da = labels_ds[label_measurement]

# 1. Apply Mask to Labels: Replace cloud-covered pixels with 255 (No Data)
# This prevents the model from being penalized for trying to predict land cover 
# where we know the input data is corrupted (cloudy).
labels_masked = labels_da.where(cloud_free_mask, 255)

# 2. Convert to final format (typically Int64 for PyTorch labels)
labels_final = labels_masked.astype('int64')

# --- Step D: Rechunk for Training ---
# Reshapes the lazy Dask graph chunks to be efficient for GPU training batches
ds_final = ds_final.chunk({'time': 1, 'x': 512, 'y': 512})
labels_final = labels_final.chunk({'time': 1, 'x': 512, 'y': 512})

print("\nData Pipeline Ready for PyTorch Dataset.")


Data Pipeline Ready for PyTorch Dataset.


In [38]:

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from dask.distributed import get_client, Client

In [90]:
class LandCoverDataset(Dataset):
    """
    A PyTorch Dataset that loads one Landsat scene (input) and its corresponding 
    Land Cover map (target) for semantic segmentation, using Dask for lazy I/O.
    """
    
    def __init__(self, input_data, target_labels): 
        self.input_data = input_data      # ds_final (Input, Float32, C, Y, X)
        self.target_labels = target_labels # labels_final (Labels, Int64, Y, X)
        print(input_data)
        # FIX 2: Use a defined variable (input_data) to get the time dimension size
        self.num_samples = self.input_data.sizes['time'] 
        # Get the number of chunks in each dimension
        time_chunks = len(input_data.chunks[0])
        y_chunks = len(input_data.chunks[2])
        x_chunks = len(input_data.chunks[3])
        
        # 2. Store the chunk indices as attributes of the object (self)
        self.x_chunks = x_chunks
        self.y_chunks = y_chunks
        
        # 3. Calculate the total number of patches (samples)
        self.num_samples = time_chunks * y_chunks * x_chunks
        # --- Critical Dask Setup ---
        try:
            self.client = get_client()
        except ValueError:
            print("Warning: Dask client not found. Running computation locally.")
            self.client = None

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        if idx >= self.num_samples:
            raise IndexError("Index out of bounds")
        
        patches_per_scene = self.x_chunks * self.y_chunks
        time_idx = idx // patches_per_scene # Integer division
        
        # Calculate the Patch index within the scene
        patch_idx_flat = idx % patches_per_scene
        
        # Convert flat patch index to (y_idx, x_idx) grid coordinates
        y_idx = patch_idx_flat // self.x_chunks
        x_idx = patch_idx_flat % self.x_chunks
        
        
        # 2. Select the specific chunk (DataArray.chunks are available via Dask)
        # This selects the (time_idx) scene, and the (y_idx, x_idx) chunk within that scene.
        
        input_slice = self.input_data.isel(time=time_idx, x=slice(x_idx, x_idx+1), y=slice(y_idx, y_idx+1)).squeeze()
        label_slice = self.target_labels.isel(time=time_idx, x=slice(x_idx, x_idx+1), y=slice(y_idx, y_idx+1)).squeeze()
        
        # CRITICAL: We use .squeeze() because isel(slice) retains the size=1 dimension, 
        # but we only want (C, H, W) for the patch.
        
        # 3. Trigger computation on both slices using Dask (Remaining logic is the same)
        # Send both small computation graphs to the Dask cluster
        input_future, label_future = self.client.compute(
            [input_slice.data, label_slice.data]
        )
        input_np = input_future.result()
        label_np = label_future.result()

        
        # 4. To torch
        input_tensor = torch.from_numpy(input_np.astype("float32"))
        target_tensor = torch.from_numpy(label_np.astype("int64"))
        
        return input_tensor, target_tensor



In [91]:
# --- Instantiate and Test ---
# print(dataset)          # or
# print(dataset.data_vars)

# 1. Create the Dataset instance
eo_dataset = LandCoverDataset(ds_final, labels_ds)

# NOTE: The print below will still show the large shape unless you fix the chunking!
print(f"Total number of scenes (samples) in Dataset: {len(eo_dataset)}")
print(f"Shape of one sample (C, Y, X): {ds_final.isel(time=0).shape}")

# 2. Create a basic PyTorch DataLoader (num_workers=0 for now)
batch_size = 4
eo_dataloader = DataLoader(
    eo_dataset, 
    batch_size=batch_size, 
    shuffle=True, 
    num_workers=0 
)

# 3. Test retrieving one batch of data
# FIX APPLIED HERE: Removed the 'level' keyword argument
heading("Testing Dask to PyTorch Data Flow") 

# Iterate through the DataLoader to fetch the first batch
# This will trigger the Dask computation for the first 4 scenes
data_batch, labels_batch = next(iter(eo_dataloader))

print(f"\nSuccessfully loaded one batch of data!")
print(f"Batch Tensor Shape (B, C, H, W): {data_batch.shape}")
print(f"Batch Label Shape: {labels_batch.shape}")
print(f"Data type: {data_batch.dtype}")
print("-" * 40)

print("Check your Dask Dashboard link to see activity during the load!")

<xarray.DataArray (band: 6, time: 43, y: 6441, x: 5469)> Size: 36GB
dask.array<rechunk-merge, shape=(6, 43, 6441, 5469), dtype=float32, chunksize=(1, 1, 512, 512), chunktype=numpy.ndarray>
Coordinates:
  * time         (time) datetime64[ns] 344B 2020-02-08T23:56:31.774346 ... 20...
  * y            (y) float64 52kB -3.922e+06 -3.922e+06 ... -4.115e+06
  * x            (x) float64 44kB 1.272e+06 1.272e+06 ... 1.436e+06 1.436e+06
    spatial_ref  int32 4B 3577
  * band         (band) object 48B 'nbart_blue' 'nbart_green' ... 'nbart_swir_2'
Attributes:
    crs:           EPSG:3577
    grid_mapping:  spatial_ref
Total number of scenes (samples) in Dataset: 858
Shape of one sample (C, Y, X): (6, 6441, 5469)


NameError: name 'input_da' is not defined

In [63]:
client.close()
cluster.shutdown()