## Data scripts

This notebook contains the code and tests for the data processing and dataset logic for this project

In [1]:
import xarray
import zarr
from numcodecs.zarr3 import PCodec
import numpy
import os 
from torch.utils.data import Dataset, IterableDataset
from helpers import set_seed, set_device, get_logger
import warnings
from tqdm import tqdm
import h5py


LOGGER = get_logger("worldclim-dataset")

#### Data Processing

In [2]:
def _find_dataset(path: str) -> str:
    """
    Checks that the dataset exists at the given path
    
    Args:
        path (str): The path to the dataset
        
    Returns:
        str: Confirmation that the dataset exists
        
    Raises:
        FileNotFoundError: If the dataset does not exist
    """
    if os.path.exists(path):
        return f"Dataset found at {path}"
    else:
        raise FileNotFoundError(f"Dataset not found at {path}")
        

In [3]:
path = r"../data/worldclim2.zarr"
_find_dataset(path)

'Dataset found at ../data/worldclim2.zarr'

In [4]:
def _load_dataset(path: str) -> xarray.Dataset:
    """
    Loads the dataset from the given path
    
    Args:
        path (str): The path to the dataset
        
    Returns:
        xarray.Dataset: The dataset
    """
    LOGGER.info(f"Loading dataset from {path}")
    warnings.filterwarnings(
        "ignore",
        message="Numcodecs codecs are not in the Zarr version 3 specification.*",
        category=UserWarning,
        module="numcodecs.zarr3"
    )
    try: 
        dataset = xarray.open_zarr(path)
        dataset = dataset
        print(f"Dataset loaded from {path}")
        return dataset
    except Exception as e:
        raise ValueError(f"Error loading dataset from {path}: {e}")

In [5]:
ds = _load_dataset(path)
ds

Dataset loaded from ../data/worldclim2.zarr


Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [6]:
def _land_compute_mask(dataset: xarray.Dataset, land_mask_value: float = -32768, land_mask_variable: str ='elev') -> xarray.DataArray:
    """
    Computes the mask for the dataset
    
    Args:
        dataset (xarray.Dataset): The dataset
        land_mask_value (float): The value to use for the mask
        land_mask_variable (str): The variable to use for the mask. Defaults to 'elev'.
        
    Returns:
        xarray.DataArray: The mask
    """
    LOGGER.info(f"Computing mask for {land_mask_variable} with value {land_mask_value}")
    
    # check that the variable exists
    if land_mask_variable not in dataset.data_vars:
        raise ValueError(f"Variable {land_mask_variable} not found in dataset")
    
    # create mask: True where the data is not equal to the mask_value
    land_mask = dataset[land_mask_variable].isel(t=0) != land_mask_value
    
    land_count = land_mask.values.sum()
    total_count = land_mask.values.size
    print(f"Land count: {land_count}")
    print(f"Total count: {total_count}")
    print(f"percentage land: {land_count / total_count * 100}")
    LOGGER.info(f"Land count: {land_count}")
    LOGGER.info(f"Total count: {total_count}")
    del land_count
    del total_count
    
    LOGGER.info(f"Mask computed for {land_mask_variable} with value {land_mask_value}")
    print(f"Mask computed for {land_mask_variable} with value {land_mask_value}")
    return land_mask
    
    

In [7]:
mask = _land_compute_mask(ds)

Land count: 309278141
Total count: 933120000
percentage land: 33.144519568758575
Mask computed for elev with value -32768


In [8]:
mask.isel(x=0, y=0).coords.values

<bound method Mapping.values of Coordinates:
    x        float64 8B -180.0
    y        float64 8B 90.0
    t        int32 4B 1>

In [9]:
mask

Unnamed: 0,Array,Chunk
Bytes,889.89 MiB,256.00 kiB
Shape,"(21600, 43200)","(512, 512)"
Dask graph,3655 chunks in 4 graph layers,3655 chunks in 4 graph layers
Data type,bool numpy.ndarray,bool numpy.ndarray
"Array Chunk Bytes 889.89 MiB 256.00 kiB Shape (21600, 43200) (512, 512) Dask graph 3655 chunks in 4 graph layers Data type bool numpy.ndarray",43200  21600,

Unnamed: 0,Array,Chunk
Bytes,889.89 MiB,256.00 kiB
Shape,"(21600, 43200)","(512, 512)"
Dask graph,3655 chunks in 4 graph layers,3655 chunks in 4 graph layers
Data type,bool numpy.ndarray,bool numpy.ndarray


In [10]:
34556//42300

0

In [11]:
def _split_land_ocean_coord_generator(mask: xarray.DataArray):
    """
    Splits the dataset into land and ocean coordinates
    
    Args:
        mask (xarray.DataArray): The mask for the dataset
    
    Returns:
        tuple: The land and ocean datasets
    """
    LOGGER.info("Splitting dataset into land and ocean based on mask")
    
    x_values = mask.coords['x'].values
    y_values = mask.coords['y'].values
    
    width = len(x_values)
    
    land_mask_values = mask.values.ravel()
    total = len(land_mask_values)
    
    for i, is_land in enumerate(tqdm(land_mask_values, desc="processing the land mask", total=total)):
        row = i // width
        col = i % width
        coord = (y_values[row], x_values[col])
        yield ('land' if is_land else 'ocean'), coord

In [18]:
def _split_land_ocean_coords(land_mask: xarray.DataArray, output_file: str):
    """
    Splits the land_mask into land and ocean coordinates, writes them to HDF5.
    
    Args:
        land_mask (xarray.DataArray): A boolean mask (True = land, False = ocean).
        output_file (str): Output path for HDF5 file.
    """
    LOGGER.info("Splitting dataset into land and ocean using the generator")

    with h5py.File(output_file, "w") as f:
        land_coords = f.create_dataset("land_coords", (0, 2), maxshape=(None, 2), dtype="f8")
        ocean_coords = f.create_dataset("ocean_coords", (0, 2), maxshape=(None, 2), dtype="f8")

        land_count = 0
        ocean_count = 0

        for coord_type, coord in _split_land_ocean_coord_generator(land_mask):
            if coord_type == 'land':
                land_coords.resize((land_count + 1, 2))
                land_coords[land_count] = coord
                land_count += 1
            else:
                ocean_coords.resize((ocean_count + 1, 2))
                ocean_coords[ocean_count] = coord
                ocean_count += 1

        # Optional: store metadata
        f.attrs["land_count"] = land_count
        f.attrs["ocean_count"] = ocean_count
        f.attrs["total"] = land_count + ocean_count

        LOGGER.info(f"Finished writing: {land_count} land, {ocean_count} ocean coordinates")

        # Example of checking contents
        print(f"Land count: {land_count}")
        print(f"Ocean count: {ocean_count}")
        print(f"Total: {land_count + ocean_count}")

In [None]:
coordinates = _split_land_ocean_coords(mask, 'coordinates.h5')

processing the land mask:   0%|          | 113852/933120000 [00:03<8:57:50, 28911.86it/s]

In [43]:
coordinates['land'][0]

NameError: name 'coordinates' is not defined

In [30]:
ocean_ds

Coordinates:
  * x        (x) float64 346kB -180.0 -180.0 -180.0 -180.0 ... 180.0 180.0 180.0
  * y        (y) float64 168kB 90.0 89.99 89.98 89.97 ... -85.21 -85.22 -85.23
    t        int32 4B 1

In [41]:
land_ds.sizes

Frozen({'y': 20246, 'x': 43200})

In [42]:
ocean_ds.sizes

Frozen({'y': 21028, 'x': 43200})

In [51]:
len(land_ds['y'].values)

20246

In [52]:
len(land_ds['x'].values)

43200

In [53]:
land_ds.to_index()

MultiIndex([( 83.65416666666667, -179.99583333333334),
            ( 83.65416666666667,           -179.9875),
            ( 83.65416666666667, -179.97916666666666),
            ( 83.65416666666667, -179.97083333333333),
            ( 83.65416666666667,           -179.9625),
            ( 83.65416666666667, -179.95416666666668),
            ( 83.65416666666667, -179.94583333333333),
            ( 83.65416666666667,           -179.9375),
            ( 83.65416666666667, -179.92916666666667),
            ( 83.65416666666667, -179.92083333333335),
            ...
            (-89.99583333333334,  179.92083333333335),
            (-89.99583333333334,  179.92916666666667),
            (-89.99583333333334,            179.9375),
            (-89.99583333333334,  179.94583333333333),
            (-89.99583333333334,  179.95416666666665),
            (-89.99583333333334,  179.96249999999998),
            (-89.99583333333334,   179.9708333333333),
            (-89.99583333333334,  179.97916666666

In [None]:
ocean_ds.to_index()

In [44]:
conda list

# packages in environment at C:\Users\micke\anaconda3\envs\worldclim:
#
# Name                    Version                   Build  Channel
_openmp_mutex             4.5                       2_gnu    conda-forge
annotated-types           0.7.0                    pypi_0    pypi
asttokens                 3.0.0                    pypi_0    pypi
aws-c-auth                0.8.0               h2219d47_15    conda-forge
aws-c-cal                 0.8.1                h099ea23_3    conda-forge
aws-c-common              0.10.6               h2466b09_0    conda-forge
aws-c-compression         0.3.0                h099ea23_5    conda-forge
aws-c-event-stream        0.5.0               h85d8506_11    conda-forge
aws-c-http                0.9.2                h3888f84_4    conda-forge
aws-c-io                  0.15.3               hc5a9e45_6    conda-forge
aws-c-mqtt                0.11.0              h2c94728_12    conda-forge
aws-c-s3                  0.7.7                h6a38c86_0    conda-forge