## Data scripts

This notebook contains the code and tests for this project's data processing and dataset logic.

In [1]:
import xarray
import zarr
from numcodecs.zarr3 import PCodec
import numpy
import os 
from torch.utils.data import Dataset, IterableDataset
from helpers import set_seed, set_device, get_logger
import warnings

# Logger used by the data-processing helpers in this notebook (get_logger comes from the local `helpers` module).
LOGGER = get_logger("worldclim-dataset")

#### Data Processing

In [2]:
def _find_dataset(path: str) -> str:
    """
    Checks that the dataset exists at the given path
    
    Args:
        path (str): The path to the dataset
        
    Returns:
        str: Confirmation that the dataset exists
        
    Raises:
        FileNotFoundError: If the dataset does not exist
    """
    if os.path.exists(path):
        return f"Dataset found at {path}"
    else:
        raise FileNotFoundError(f"Dataset not found at {path}")
        

In [3]:
# Location of the Zarr store, resolved relative to the kernel's current working directory.
path = r"../data/worldclim2.zarr"
_find_dataset(path)

'Dataset found at ../data/worldclim2.zarr'

In [4]:
def _load_dataset(path: str) -> xarray.Dataset:
    """
    Opens the Zarr store at the given path as an xarray Dataset.

    ``xarray.open_zarr`` opens the store lazily: variables are backed by dask
    arrays and chunk data is only read when computed.

    Args:
        path (str): The path to the Zarr store

    Returns:
        xarray.Dataset: The dataset

    Raises:
        ValueError: If xarray fails to open the store. The underlying exception
            is chained as ``__cause__`` so the original traceback is preserved.
    """
    LOGGER.info(f"Loading dataset from {path}")
    # Silence numcodecs' warning that its codecs are not part of the Zarr v3 spec.
    # The filter is deliberately left process-wide (not scoped with
    # warnings.catch_warnings): the dataset is lazy, so codecs may still be used
    # after this function has returned.
    warnings.filterwarnings(
        "ignore",
        message="Numcodecs codecs are not in the Zarr version 3 specification.*",
        category=UserWarning,
        module="numcodecs.zarr3"
    )
    try:
        dataset = xarray.open_zarr(path)
    except Exception as e:
        # Chain the original error so its type and traceback are not lost.
        raise ValueError(f"Error loading dataset from {path}: {e}") from e
    print(f"Dataset loaded from {path}")
    return dataset

In [5]:
# Open the store; the repr below shows the lazily loaded (dask-backed) variables without reading their data.
ds = _load_dataset(path)
ds

Dataset loaded from ../data/worldclim2.zarr


Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [6]:
def _land_compute_mask(dataset: xarray.Dataset, land_mask_value: float = -32768, land_mask_variable: str ='elev') -> xarray.DataArray:
    """
    Computes a boolean land mask from a sentinel value in one variable.

    A cell is marked True (land) when ``land_mask_variable`` differs from
    ``land_mask_value``. If the variable has a ``t`` dimension, only its first
    slice (``t=0``) is used.

    Args:
        dataset (xarray.Dataset): The dataset
        land_mask_value (float): Sentinel value marking non-land cells. Defaults to -32768.
            If NaN, cells that are NaN are treated as non-land.
        land_mask_variable (str): The variable to use for the mask. Defaults to 'elev'.

    Returns:
        xarray.DataArray: Boolean mask, True where the variable is not the sentinel.

    Raises:
        ValueError: If ``land_mask_variable`` is not a data variable of ``dataset``.
    """
    LOGGER.info(f"Computing mask for {land_mask_variable} with value {land_mask_value}")

    # check that the variable exists
    if land_mask_variable not in dataset.data_vars:
        raise ValueError(f"Variable {land_mask_variable} not found in dataset")

    values = dataset[land_mask_variable]
    # Only one time step is needed. Skip the selection for variables without a
    # 't' dimension instead of failing inside isel().
    if "t" in values.dims:
        values = values.isel(t=0)

    if numpy.isnan(land_mask_value):
        # NaN != NaN is always True, so a plain comparison would mark every cell as land.
        land_mask = values.notnull()
    else:
        # NOTE(review): NaN cells compare unequal to the sentinel and are therefore
        # marked as land here; confirm the store holds the raw sentinel, not NaN.
        land_mask = values != land_mask_value

    LOGGER.info(f"Mask computed for {land_mask_variable} with value {land_mask_value}")
    return land_mask
    
    

In [25]:
# Build the land mask and spot-check a single cell.
# NOTE(review): (x=-91.25, y=-0.8) is presumably a known land point; x/y look like
# lon/lat in degrees given the coordinate ranges printed further down — confirm.
mask = _land_compute_mask(ds)
mask.sel(x=-91.25, y=-0.8, method='nearest').values

array(True)

In [28]:
def _split_land_ocean_coords(dataset: xarray.Dataset, land_mask: xarray.DataArray) -> tuple:
    """
    Finds the coordinates spanned by land cells and by ocean cells.

    For each dimension of ``land_mask``, a coordinate label is kept for land if
    at least one cell along the other dimensions is land, and kept for ocean if
    at least one such cell is ocean. The result is therefore a bounding grid of
    rows/columns, NOT the set of individual land/ocean points. The land
    coordinates also include ocean cells that share a row or column with land,
    and vice versa. These are the same coordinates that
    ``land_mask.where(land_mask, drop=True)`` (resp. ``~land_mask``) would produce.

    Args:
        dataset (xarray.Dataset): Unused; kept so existing callers keep working.
        land_mask (xarray.DataArray): Boolean mask, True for land cells.

    Returns:
        tuple: ``(land_coords, ocean_coords)``, the coordinate objects of the
        land-spanning and ocean-spanning subsets of ``land_mask``.
    """
    LOGGER.info("Splitting dataset into land and ocean based on mask")
    # Materialise the (possibly lazy) mask once so the reductions below reuse it.
    land_mask = land_mask.load()

    land_indexers = {}
    ocean_indexers = {}
    for dim in land_mask.dims:
        others = [d for d in land_mask.dims if d != dim]
        if others:
            any_land = land_mask.any(dim=others)
            all_land = land_mask.all(dim=others)
        else:
            any_land = all_land = land_mask
        # "Some cell is ocean" == "not every cell is land", so ~mask is never materialised.
        land_indexers[dim] = numpy.flatnonzero(any_land.values)
        ocean_indexers[dim] = numpy.flatnonzero(~all_land.values)

    # Index the boolean mask directly rather than using where(drop=True). That
    # would allocate a NaN-filled (non-boolean) copy of the clipped grid just to
    # read its coordinates, and the returned `.coords` would keep that copy alive.
    land_coords = land_mask.isel(land_indexers).coords
    ocean_coords = land_mask.isel(ocean_indexers).coords
    LOGGER.info("Splitting complete")

    return land_coords, ocean_coords

In [29]:
# NOTE: despite the `_ds` suffix these are xarray coordinate objects, not Datasets.
# Each covers every row/column containing at least one land (resp. ocean) cell, so the
# two overlap (their y sizes below sum to more than the full grid height).
land_ds, ocean_ds = _split_land_ocean_coords(ds, mask)
land_ds

Coordinates:
  * x        (x) float64 346kB -180.0 -180.0 -180.0 -180.0 ... 180.0 180.0 180.0
  * y        (y) float64 162kB 83.65 83.65 83.64 83.63 ... -89.98 -89.99 -90.0
    t        int32 4B 1

In [30]:
# Coordinates of the rows/columns that contain at least one ocean cell.
ocean_ds

Coordinates:
  * x        (x) float64 346kB -180.0 -180.0 -180.0 -180.0 ... 180.0 180.0 180.0
  * y        (y) float64 168kB 90.0 89.99 89.98 89.97 ... -85.21 -85.22 -85.23
    t        int32 4B 1

In [41]:
# Number of y and x labels kept for land.
land_ds.sizes

Frozen({'y': 20246, 'x': 43200})

In [42]:
# Number of y and x labels kept for ocean.
ocean_ds.sizes

Frozen({'y': 21028, 'x': 43200})

In [51]:
# Equivalent to land_ds.sizes['y']; kept as a cross-check of the sizes above.
len(land_ds['y'].values)

20246

In [52]:
# Equivalent to land_ds.sizes['x']; kept as a cross-check of the sizes above.
len(land_ds['x'].values)

43200

In [None]:
# WARNING: this builds a pandas MultiIndex over the full y × x product of the land
# coordinates (20246 × 43200 ≈ 875M entries per the sizes above). It will likely
# need several GB of memory. These pairs also include ocean cells that share a
# row/column with land; they are not the individual land points.
land_ds.to_index()