## Data scripts

This notebook contains the code and tests for the data processing and dataset logic for this project

In [10]:
import xarray
import zarr
from numcodecs.zarr3 import PCodec
import numpy
import os 
from torch.utils.data import Dataset, IterableDataset
from helpers import set_seed, set_device, get_logger
import warnings
from tqdm import tqdm
import h5py
import json


LOGGER = get_logger(name = "worldclim-dataset", log_file="worldclim-dataset.log")

#### Data Processing

In [3]:
def _find_dataset(path: str) -> str:
    """
    Checks that the dataset exists at the given path
    
    Args:
        path (str): The path to the dataset
        
    Returns:
        str: Confirmation that the dataset exists
        
    Raises:
        FileNotFoundError: If the dataset does not exist
    """
    LOGGER.info(f"Checking that dataset exists at {path}")
    if os.path.exists(path):
        LOGGER.info(f"Dataset found at {path}")
        LOGGER.info("____________________________________________")
        return f"Dataset found at {path}"
    else:
        raise FileNotFoundError(f"Dataset not found at {path}")
        

In [5]:
path = r"../data/worldclim2.zarr"
_find_dataset(path)

'Dataset found at ../data/worldclim2.zarr'

In [4]:
def _load_dataset(path: str) -> xarray.Dataset:
    """
    Loads the dataset from the given path
    
    Args:
        path (str): The path to the dataset
        
    Returns:
        xarray.Dataset: The dataset
    """
    LOGGER.info(f"Loading dataset from {path}")
    warnings.filterwarnings(
        "ignore",
        message="Numcodecs codecs are not in the Zarr version 3 specification.*",
        category=UserWarning,
        module="numcodecs.zarr3"
    )
    try: 
        dataset = xarray.open_zarr(path)
        dataset = dataset
        print(f"Dataset loaded from {path}")
        LOGGER.info("DATASET LOADED")
        LOGGER.info("____________________________________________")
        return dataset
    except Exception as e:
        raise ValueError(f"Error loading dataset from {path}: {e}")

In [6]:
ds = _load_dataset(path)
ds

Dataset loaded from ../data/worldclim2.zarr


Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 83.43 GiB 2.00 MiB Shape (12, 21600, 43200) (1, 512, 512) Dask graph 43860 chunks in 2 graph layers Data type float64 numpy.ndarray",43200  21600  12,

Unnamed: 0,Array,Chunk
Bytes,83.43 GiB,2.00 MiB
Shape,"(12, 21600, 43200)","(1, 512, 512)"
Dask graph,43860 chunks in 2 graph layers,43860 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray


In [6]:
def _land_compute_mask(dataset: xarray.Dataset, land_mask_value: float = -32768, land_mask_variable: str ='elev') -> xarray.DataArray:
    """
    Computes the mask for the dataset
    
    Args:
        dataset (xarray.Dataset): The dataset
        land_mask_value (float): The value to use for the mask
        land_mask_variable (str): The variable to use for the mask. Defaults to 'elev'.
        
    Returns:
        xarray.DataArray: The mask
    """
    LOGGER.info(f"Computing mask for {land_mask_variable} with value {land_mask_value}")
    
    # check that the variable exists
    if land_mask_variable not in dataset.data_vars:
        raise ValueError(f"Variable {land_mask_variable} not found in dataset")
    
    # create mask: True where the data is not equal to the mask_value
    land_mask = dataset[land_mask_variable].isel(t=0) != land_mask_value
    
    land_count = land_mask.values.sum()
    total_count = land_mask.values.size
    print(f"Land count: {land_count}")
    print(f"Total count: {total_count}")
    print(f"percentage land: {land_count / total_count * 100}")
    LOGGER.info(f"Land count: {land_count}")
    LOGGER.info(f"Total count: {total_count}")
    del land_count
    del total_count
    
    LOGGER.info(f"MASK COMPUTED FOR {land_mask_variable} WITH VALUE {land_mask_value}")
    print(f"Mask computed for {land_mask_variable} with value {land_mask_value}")
    LOGGER.info("____________________________________________")
    return land_mask
    
    

In [7]:
mask = _land_compute_mask(ds)

Land count: 309278141
Total count: 933120000
percentage land: 33.144519568758575
Mask computed for elev with value -32768


In [8]:
mask.sel(x=-33.87916667, y=83.65416667, method='nearest')

Unnamed: 0,Array,Chunk
Bytes,1 B,1 B
Shape,(),()
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,bool numpy.ndarray,bool numpy.ndarray
Array Chunk Bytes 1 B 1 B Shape () () Dask graph 1 chunks in 5 graph layers Data type bool numpy.ndarray,,

Unnamed: 0,Array,Chunk
Bytes,1 B,1 B
Shape,(),()
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,bool numpy.ndarray,bool numpy.ndarray


In [9]:
34556//42300

0

In [11]:
def _split_land_ocean_coord_generator(mask: xarray.DataArray):
    """
    Splits the dataset into land and ocean coordinates
    
    Args:
        mask (xarray.DataArray): The mask for the dataset
    
    Returns:
        tuple: The land and ocean datasets
    """
    LOGGER.info("Splitting dataset into land and ocean based on mask")
    
    x_values = mask.coords['x'].values
    y_values = mask.coords['y'].values
    
    width = len(x_values)
    
    land_mask_values = mask.values.ravel()
    total = len(land_mask_values)
    
    for i, is_land in enumerate(tqdm(land_mask_values, desc="processing the land mask", total=total)):
        row = i // width
        col = i % width
        coord = (y_values[row], x_values[col])
        yield ('land' if is_land else 'ocean'), coord
    

In [14]:
def _split_land_ocean_coords(land_mask: xarray.DataArray, output_file: str):
    """
    Splits the land_mask into land and ocean coordinates, writes them to HDF5.
    
    Args:
        land_mask (xarray.DataArray): A boolean mask (True = land, False = ocean).
        output_file (str): Output path for HDF5 file.
    """
    LOGGER.info("SPLITTING LAND AND OCEAN COORDINATES")

    with h5py.File(output_file, "w") as f:
        land_coords = f.create_dataset("land_coords", (0, 2), maxshape=(None, 2), dtype="f8", compression="gzip")
        ocean_coords = f.create_dataset("ocean_coords", (0, 2), maxshape=(None, 2), dtype="f8", compression="gzip")

        land_count = 0
        ocean_count = 0

        for coord_type, coord in _split_land_ocean_coord_generator(land_mask):
            if coord_type == 'land':
                land_coords.resize((land_count + 1, 2))
                land_coords[land_count] = coord
                land_count += 1
            else:
                ocean_coords.resize((ocean_count + 1, 2))
                ocean_coords[ocean_count] = coord
                ocean_count += 1

        # Optional: store metadata
        f.attrs["land_count"] = land_count
        f.attrs["ocean_count"] = ocean_count
        f.attrs["total"] = land_count + ocean_count

        LOGGER.info(f"FINISHED SPLITTING LAND AND OCEAN COORDS")
        LOGGER.info("____________________________________________")

        # Example of checking contents
        print(f"Land count: {land_count}")
        print(f"Ocean count: {ocean_count}")
        print(f"Total: {land_count + ocean_count}")

In [19]:
coordinates = _split_land_ocean_coords(mask, 'coordinates.h5')

processing the land mask: 100%|██████████| 933120000/933120000 [8:24:28<00:00, 30827.62it/s]   


Land count: 309278141
Ocean count: 623841859
Total: 933120000


In [15]:
f = h5py.File('coordinates.h5', 'r')
f

<HDF5 file "coordinates.h5" (mode r)>

In [16]:
f.keys()

<KeysViewHDF5 ['land_coords', 'ocean_coords']>

In [17]:
 f['land_coords'][0]

array([ 83.65416667, -33.87916667])

In [11]:
def _get_normalized_stats(dataset: xarray.Dataset):
    """
    Computes the normalized statistics for the dataset
    
    Args:
        dataset (xarray.Dataset): The dataset
    
    Returns:
        dict: The normalized statistics for the dataset
    """
    LOGGER.info("COMMENCING COMPUTING NORMALIZED STATS")
    
    normalized_stats = {}
    condition = dataset['elev'] != -32768.0
    
    for variable in dataset.data_vars:
        LOGGER.info(f"Computing normalized statistics for {variable}")
        print(f"Computing normalized statistics for {variable}")
        
        
        masked = dataset[variable].where(condition)
        min_value = masked.min().compute()
        max_value = masked.max().compute()
        
        normalized_stats[variable] = [min_value.item(), max_value.item()]
        del min_value, max_value
        
    with open('../normalized_stats.json', 'w') as f:
        json.dump(normalized_stats, f) 
        
    LOGGER.info("COMPLETED COMPUTING NORMALIZED STATS")
    LOGGER.info("____________________________________________")
    return normalized_stats

In [12]:
_get_normalized_stats(ds)

Computing normalized statistics for tavg
Computing normalized statistics for tmax
Computing normalized statistics for elev
Computing normalized statistics for vapr
Computing normalized statistics for wind
Computing normalized statistics for prec
Computing normalized statistics for srad
Computing normalized statistics for tmin


{'tavg': [-68.5, 39.900001525878906],
 'tmax': [-64.5, 48.599998474121094],
 'elev': [-415.0, 8424.0],
 'vapr': [0.0, 3.569999933242798],
 'wind': [0.0, 36.099998474121094],
 'prec': [0.0, 2982.0],
 'srad': [0.0, 45989.0],
 'tmin': [-72.5999984741211, 32.400001525878906]}

In [18]:
def run():
    if not os.path.exists('coordinates.h5'):
        LOGGER.info("RUNNING")
        _find_dataset(path)
        dataset = _load_dataset(path)
        mask = _land_compute_mask(dataset)
        _split_land_ocean_coords(mask, 'coordinates.h5')
    if not os.path.exists('normalized_stats.json'):
        _get_normalized_stats(dataset)   
    LOGGER.info("COMPLETED RUNNING")
    LOGGER.info("____________________________________________")

MemoryError: Unable to allocate 83.4 GiB for an array with shape (12, 21600, 43200) and data type float64

In [15]:
# Path to your JSON file
with open('normalized_stats.json', 'r') as f:
    normalized_stats = json.load(f)

# Now you can access it like a regular Python dictionary
print(normalized_stats)
print(normalized_stats.keys())

{'tavg': [-68.5, 39.900001525878906], 'tmax': [-64.5, 48.599998474121094], 'elev': [-415.0, 8424.0], 'vapr': [0.0, 3.569999933242798], 'wind': [0.0, 36.099998474121094], 'prec': [0.0, 2982.0], 'srad': [0.0, 45989.0], 'tmin': [-72.5999984741211, 32.400001525878906]}
dict_keys(['tavg', 'tmax', 'elev', 'vapr', 'wind', 'prec', 'srad', 'tmin'])
