In [0]:
import rasterio
from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, MapType, LongType, TimestampType
from datetime import datetime, UTC
import numpy as np
import os

## Design decision: When to split files?

Step one is to think about how we might want to parallelise / partition the read process. By default, Spark will create a task per input file it reads, but we have a relatively small number of files to work with here, and reading any one of these into memory in its entirety might result in an out-of-memory situation.

The good news is: we can take advantage of the hierarchical nature of these files to parallelise the read process and scale it out across a large cluster. All we need to do is extend the `pyspark.sql.datasource.InputPartition` class and implement a method called `partitions` in our `DataSourceReader` that exposes the full set of partitions to be read.

We'll partition across both **subdataset** and **band**.

In [0]:
class SubdatasetBandPartition(InputPartition):
  """One unit of parallel read work: a single band of a single raster subdataset.

  Attributes:
    subdataset: the dataset URI/path that rasterio opens for this partition.
    band: the 1-based band index to read from that subdataset.
  """

  subdataset: str
  band: int

  def __init__(self, subdataset: str, band: int):
    # Spark pickles partitions and ships them to executors; store just the two
    # plain values the reader needs to locate its slice of data.
    self.subdataset, self.band = subdataset, band

Here's our `ERA5DataSourceReader` class. It uses the `rasterio` library to read the metadata and contents of each file and yields a 7-tuple `(subdataset, metadata, band, time, x, y, m)` per grid cell of each band.

In [0]:
class ERA5DataSourceReader(DataSourceReader):
  """Reads ERA5 netCDF files as rows of (subdataset, metadata, band, time, x, y, m).

  Work is split into one ``SubdatasetBandPartition`` per (subdataset, band) pair, so a
  single large file is read by many Spark tasks instead of one.

  Options:
    path: a directory searched recursively for ``.nc`` files, or a single ``.nc`` file.
  """

  # Metadata tag holding every time-step value of the file as one delimited,
  # comma-separated string (the first and last characters are delimiters).
  _TIME_VALUES_TAG = "NETCDF_DIM_valid_time_VALUES"

  def __init__(self, schema, options):
    self.schema: StructType = schema
    self.options = options

  @staticmethod
  def unix_to_datetime(unix_timestamp: int) -> datetime:
    """Convert seconds since the Unix epoch to a timezone-aware UTC datetime."""
    return datetime.fromtimestamp(unix_timestamp, UTC)

  @staticmethod
  def convert_longitude_0_360_to_180(lon_array):
    """
    Convert longitude from 0-360° range to -180 to +180° range

    Parameters:
    lon_array (array-like): Longitudes in 0-360° range

    Returns:
    numpy.ndarray: Longitudes in -180 to +180° range (values <= 180 are unchanged)
    """
    lons = np.asarray(lon_array)
    # np.where always allocates a new array, so the input is never modified.
    return np.where(lons > 180, lons - 360, lons)

  @property
  def input_path(self) -> str:
    """The required ``path`` option; raises ValueError when it is missing or empty."""
    path: str = self.options.get("path")
    if not path:
      raise ValueError("The 'path' option is required.")
    return path

  @property
  def input_files(self) -> list[str]:
    """Return the ``.nc`` files to read, in a deterministic (sorted) order.

    ``os.walk`` yields nothing for a plain file path, so a path that points at a
    single file is returned as-is instead of silently producing no input.
    """
    path = self.input_path
    if os.path.isfile(path):
      return [path]
    return sorted(
      os.path.join(root, filename)
      for root, _, filenames in os.walk(path)
      for filename in filenames
      if filename.endswith(".nc")
    )

  def partitions(self):
    """Return one SubdatasetBandPartition(subdataset, band) per band of every subdataset.

    A netCDF file containing a single data variable reports no subdatasets (the
    variable is exposed as the dataset itself), so such a file is treated as its own
    single subdataset rather than contributing no partitions.
    """
    parts: list[SubdatasetBandPartition] = []
    for path in self.input_files:
      with rasterio.open(path) as src:
        subdatasets = src.subdatasets or [path]
      for sd in subdatasets:
        with rasterio.open(sd) as r:
          # rasterio band indexes are 1-based.
          parts.extend(SubdatasetBandPartition(sd, b) for b in range(1, r.count + 1))
    return parts

  def read(self, partition):
    """Yield one row per pixel of ``partition.band`` in ``partition.subdataset``.

    Each row is (subdataset, metadata, band, time, x, y, m), where x is the pixel-centre
    longitude converted to -180..180, y the pixel-centre y coordinate, and m the pixel
    value. ``time`` is None when the subdataset has no time-values tag.
    """
    # Guards against a partition without a subdataset, e.g. when no input files were
    # found (presumably Spark then calls read() with None — confirm for your version).
    if not getattr(partition, "subdataset", None):
      return
    with rasterio.open(partition.subdataset) as r:
      metadata_tags = r.tags()
      # The list of all time steps is removed from the metadata map; only this band's
      # time step is emitted, in the `time` column.
      time_values = metadata_tags.pop(self._TIME_VALUES_TAG, None)
      valid_time = None
      if time_values is not None:
        timesteps = time_values[1:-1].split(",")
        valid_time = self.unix_to_datetime(int(timesteps[partition.band - 1]))

      height, width = r.shape
      col_idx, row_idx = np.meshgrid(np.arange(width), np.arange(height))
      # Pixel-centre coordinates via the affine coefficients (same as
      # rasterio.transform.xy with its default offset="center"), vectorised and kept
      # as (height, width) arrays aligned with the band values.
      t = r.transform
      col_centres = col_idx + 0.5
      row_centres = row_idx + 0.5
      xs = t.a * col_centres + t.b * row_centres + t.c
      ys = t.d * col_centres + t.e * row_centres + t.f
      lons = self.convert_longitude_0_360_to_180(xs)
      values = r.read(partition.band).astype(np.float32)

      # Iterate row by row (row-major, matching the original flatten order) and use
      # .tolist() to emit plain Python floats: numpy.float32 is not a `float` subclass
      # and may be rejected by Spark's DoubleType conversion. Converting one row at a
      # time keeps the number of live Python objects small.
      for lon_row, y_row, m_row in zip(lons, ys, values):
        for x, y, m in zip(lon_row.tolist(), y_row.tolist(), m_row.tolist()):
          yield (partition.subdataset, metadata_tags, partition.band, valid_time, x, y, m)

**Create our ERA5DataSource**

We need to combine our reader, which controls distribution and reading logic, with a Spark `DataSource` abstraction, which tells Spark the schema of the data the reader returns and gives it a friendly name to use in calls to `SparkSession.read`.


In [0]:
class ERA5DataSource(DataSource):
    """Spark Python data source for ERA5 netCDF files, exposed under the short name "era5".

    Typical use: ``spark.read.format("era5").load(path)``; the reader takes the files
    to read from the ``path`` option.
    """

    @classmethod
    def name(cls) -> str:
        """
        Return the short name used with ``spark.read.format(...)``.

        Returns:
            str: The data source name.
        """
        return "era5"

    def schema(self) -> StructType:
        """
        Describe the rows produced by the reader; every column is nullable.

        Returns:
            StructType: subdataset identifier, metadata tag map, 1-based band index,
            band timestamp, x and y coordinates, and the pixel value ``m``.
        """
        field_specs = [
            ("subdataset", StringType()),
            ("metadata", MapType(StringType(), StringType())),
            ("band", LongType()),
            ("time", TimestampType()),
            ("x", DoubleType()),
            ("y", DoubleType()),
            ("m", DoubleType()),
        ]
        return StructType(
            [StructField(field_name, field_type, True) for field_name, field_type in field_specs]
        )

    def reader(self, schema: StructType):
        """Create the reader that plans partitions and yields rows for ``schema``."""
        return ERA5DataSourceReader(schema, self.options)

In [0]:
## Register and use this datasource with:
# spark.dataSource.register(ERA5DataSource)
# spark.read.format("era5").load(path)