In [0]:
%pip install "xarray[io]==2024.03.0" --quiet

In [0]:
# Restart the Python process so the xarray version installed by the %pip cell above is the one imported below.
dbutils.library.restartPython()

In [0]:
from pyspark.sql.dataframe import DataFrame
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import *
from typing import List

import pandas as pd
import xarray as xr

In [0]:
# Spark schema of the per-slice CSV files: four identifying columns, then one
# non-nullable double column per day of a 365-day year (day_1 .. day_365).
csv_schema = StructType(
  [
    StructField("data_var", StringType(), False),
    StructField("year", IntegerType(), False),
    StructField("lat", DoubleType(), False),
    StructField("lon", DoubleType(), False),
  ]
  + [StructField(f"day_{day}", DoubleType(), False) for day in range(1, 365 + 1)]
)

print("...")

In [0]:
def process_slice(
  xds: "xr.Dataset", data_var: str, lat_idx: int, lon_start_idx: int = None, lon_end_idx: int = None, save_dir=None
) -> str:
  """
  Turn one latitude row (optionally a longitude sub-range) of `xds` into a wide table
  with one row per (lat, lon) and one column per day of year.

  Parameters
  - xds: opened dataset; `xds.encoding['source']` is recorded as `vol_path`.
  - data_var: name of the variable to extract.
  - lat_idx: positional index along the `lat` dimension.
  - lon_start_idx: first positional `lon` index (inclusive); defaults to 0.
  - lon_end_idx: positional `lon` index to stop at (exclusive, Python-slice style);
    defaults to `xds.lon.size`, i.e. through the last longitude.
  - save_dir: if given, the table is written as CSV under this directory.

  Returns (always a str)
  - With `save_dir`: the CSV path. If that file already exists it is returned
    immediately without recomputing.
  - Without `save_dir`: a JSON string holding slice metadata plus the table in
    pandas `orient='split'` form under the 'data' key.

  Notes
  - Assumes the slice covers a single year; the year of the first time step is used.
  - Only day_1 .. day_365 are kept (a leap-year day_366 column is dropped); a
    KeyError is raised if any of those day columns is absent.
  """
  import json
  import os

  # source file path, reported back in the JSON metadata
  vol_path = xds.encoding['source']

  save_path = ''
  if save_dir:
    os.makedirs(save_dir, exist_ok=True)
    if lon_start_idx is not None and lon_end_idx is not None:
      suffix = f'_lon_{lon_start_idx}-{lon_end_idx}'
    elif lon_start_idx is not None:
      suffix = f'_lon_start_{lon_start_idx}'
    elif lon_end_idx is not None:
      suffix = f'_lon_end_{lon_end_idx}'
    else:
      suffix = ''
    save_path = f'{save_dir}/{data_var}_lat_{lat_idx}{suffix}.csv'
    # produced by an earlier run -> skip the expensive compute
    if os.path.exists(save_path):
      return save_path

  # positional lon range; the slice end is exclusive, so the default must be the
  # full size (size - 1 would silently drop the last longitude)
  start_idx = 0 if lon_start_idx is None else lon_start_idx
  end_idx = xds.lon.size if lon_end_idx is None else lon_end_idx
  xy_slice = xds.isel(lat=lat_idx, lon=slice(start_idx, end_idx), missing_dims='ignore')

  # year for the slice (assumes there is only one)
  year = int(xy_slice.time.dt.year.values[0])

  day_cols = [f'day_{dy}' for dy in range(1, 365 + 1)]
  sel_cols = ['data_var', 'year', 'lat', 'lon'] + day_cols

  # long format: one row per (time, lon) with the variable renamed to 'value'
  long_pdf = (
    xy_slice
      .compute()
      .to_dataframe()
      .reset_index()
      .rename(columns={data_var: 'value'})
  )
  long_pdf['data_var'] = data_var
  long_pdf['year'] = year
  long_pdf['day'] = 'day_' + long_pdf['time'].dt.dayofyear.astype(str)

  # wide format: one row per (lat, lon), one column per day of year
  wide_pdf = (
    long_pdf
      .pivot(index=['data_var', 'year', 'lat', 'lon'], columns='day', values='value')
      .reset_index()
  )[sel_cols]

  if save_dir:
    wide_pdf.to_csv(save_path, index=False)
    return save_path

  # int() so numpy integer indices (e.g. from pandas rows) are JSON-serializable
  return json.dumps({
    'data_var': data_var,
    'year': year,
    'lat_idx': int(lat_idx),
    'lon_start_idx': int(start_idx),
    'lon_end_idx': int(end_idx),
    'vol_path': vol_path,
    'data': wide_pdf.to_json(orient='split'),
  })

In [0]:
from collections.abc import Iterator

def process_slice_pd(itr: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
  """
  `mapInPandas` worker: run `process_slice` for every input row and emit the
  resulting CSV path(s) as a single string column `result`.

  - Expects input columns `vol_path`, `data_var`, `lat_idx`, `vol_dir_out`.
  - Each slice is saved under `{vol_dir_out}/{basename(vol_path)}`.
  - Consecutive rows with the same `vol_path` reuse one opened dataset; a dataset is
    closed when the path changes and when the iterator finishes.
  - Works for any batch size (including empty batches), one output row per result.
  """
  # local imports: this function runs on executors, and `os` is not imported at
  # notebook scope
  import os
  import xarray as xr

  xds = None
  curr_vol_path = None
  try:
    for pdf in itr:
      results = []
      for row in pdf.itertuples(index=False):
        if xds is None or row.vol_path != curr_vol_path:
          if xds is not None:
            xds.close()
          xds = xr.open_dataset(row.vol_path)
          curr_vol_path = row.vol_path
        save_dir = f"{row.vol_dir_out}/{os.path.basename(row.vol_path)}"
        results.append(process_slice(xds, row.data_var, row.lat_idx, save_dir=save_dir))
      # one row per result; `explode` flattens any list-valued result into its own rows
      yield pd.DataFrame({'result': results}).explode('result')
  finally:
    if xds is not None:
      xds.close()

In [0]:
def log_msg(msg:str, debug_thresh:int, debug_mode:int):
  """
  Print `msg` when the active verbosity `debug_mode` is at least `debug_thresh`.
  - Single place that gates all debug output of the helpers in this notebook.
  """
  verbose_enough = debug_mode >= debug_thresh
  if verbose_enough:
    print(msg)

In [0]:
def is_config_supported():
  """
  Return True if `spark.sparkContext.getConf()` works in this session, else False.
  - Used to tell serverless compute apart from other clusters: on serverless,
    accessing the SparkContext is expected to raise, so Spark conf tweaks are skipped.
  - Only `Exception` is caught, so KeyboardInterrupt / SystemExit still propagate
    (a bare `except:` would swallow them and report False).
  """
  try:
    spark.sparkContext.getConf()
  except Exception:
    return False
  return True

print(is_config_supported())

In [0]:
def ensure_create_table(fqn_tbl:str, drop_table:bool=False, debug_mode:int=0):
  """
  Run a CREATE TABLE for `fqn_tbl` with the wide per-day layout:
  data_var string, year int, lat double, lon double, day_1 .. day_365 double.
  - drop_table=True  -> `create or replace table` (any existing table is replaced).
  - drop_table=False -> `create table if not exists` (an existing table is kept).
  - NOTE(review): `fqn_tbl` is interpolated directly into the SQL text; only pass
    trusted table names.
  """
  if drop_table:
    log_msg(f"... create or replace table: {fqn_tbl}", 1, debug_mode)
    verb = "create or replace table"
  else:
    verb = "create table if not exists"
  day_defs = "".join(f", day_{n} double" for n in range(1, 365 + 1))
  create_str = f"{verb} {fqn_tbl} (data_var string, year int, lat double, lon double{day_defs});"
  log_msg(create_str, 2, debug_mode)
  sql(create_str)

In [0]:
def process_annual(
  vol_path:str, data_var:str, fqn_tbl:str, vol_dir_out:str,
  drop_table:bool=False, skip_table:bool=False, batch_limit:int=0, debug_mode:int=0
) -> DataFrame:
  """
  Process a netcdf containing annual data into a table with one column per day of the year.

  Steps
  - [1] read the `lat` size of `vol_path`; one metadata row per latitude index.
  - [2] optionally keep only the first `batch_limit` latitude rows (0 = all).
  - [3] run `process_slice_pd` through `mapInPandas` over `n_part` random partitions
        (random repartitioning does not guarantee exactly one row per partition).
        Each row writes CSV(s) under `{vol_dir_out}/{basename(vol_path)}`. Where
        Spark conf is supported, AQE is disabled (no coalescing) and Arrow batches
        are capped at 1 row; the previous values are restored afterwards.
  - [4] unless `skip_table`, create `fqn_tbl` (replacing it when `drop_table`) and
        append every CSV row for `data_var` found in that directory (including CSVs
        left by earlier runs), with nulls filled as 0.0. Rows of other data vars that
        share the directory are excluded so they are not appended again.

  Parameters
  - debug_mode: verbosity level; >0 logs progress and displays a 5-row preview,
    >1 also shows the metadata frame and the CREATE statement.

  Returns the `fqn_tbl` DataFrame, or None when `skip_table` is True.
  """
  # local imports: `os` is not imported at notebook scope
  import os
  import xarray as xr

  # [1] get initial metadata
  # - only the lat size is needed, so close the file right away
  with xr.open_dataset(vol_path) as xds:
    n_lat = xds.sizes['lat']
  meta_rows = list(range(n_lat))
  log_msg(f"... n_lat: {n_lat:,}", 1, debug_mode)

  # [2] estimate rows
  # - including batch_limit
  if batch_limit:
    log_msg(f"... limit provided -> # batches? {batch_limit}", 1, debug_mode)
    meta_rows = meta_rows[:batch_limit]
  n_part = len(meta_rows)
  log_msg(f"... total batches estimated: {n_part:,}", 1, debug_mode)

  # [3] write batches to file
  conf_supported = is_config_supported()
  conf_overrides = {
    "spark.databricks.optimizer.adaptive.enabled": False,    # <- no coalescing
    "spark.sql.execution.arrow.maxRecordsPerBatch": str(1),  # <- 1 row is a batch
  }
  conf_previous = {}
  try:
    if conf_supported:
      for key, value in conf_overrides.items():
        # remember the caller's setting (None = not explicitly set) before overriding
        conf_previous[key] = spark.conf.get(key, None)
        spark.conf.set(key, value)

    df_meta = (
      spark
        .createDataFrame([(m,) for m in meta_rows], ["lat_idx"])
          .withColumn("data_var", F.lit(data_var))
          .withColumn("vol_path", F.lit(vol_path))
          .withColumn("vol_dir_out", F.lit(vol_dir_out))
    )
    if debug_mode > 1:
      df_meta.show(3)

    (
      df_meta
        .repartition(n_part, F.rand())
        .mapInPandas(
          process_slice_pd,
          "result string"
        )
        .write
          .format("noop")
          .mode("overwrite")
        .save()
    )
    log_msg(f"... finished batch(es) per lat", 1, debug_mode)

    # [4] write files to delta lake
    # - todo filter out already saved files
    if not skip_table:
      ensure_create_table(fqn_tbl, drop_table=drop_table, debug_mode=debug_mode)
      filename = os.path.basename(vol_path)
      (
        spark
          .read
            .csv(f"{vol_dir_out}/{filename}/*.csv", header=True, schema=csv_schema)
            .where(F.col("data_var") == data_var)
            .na.fill(0.0)
          .write
            .mode("append")
          .saveAsTable(fqn_tbl)
      )
    else:
      log_msg(f"... skip_table provided", 1, debug_mode)
  finally:
    if conf_supported:
      # restore whatever was in effect before this call
      for key, prev in conf_previous.items():
        if prev is None:
          spark.conf.unset(key)
        else:
          spark.conf.set(key, prev)

  if skip_table:
    return None
  result_df = spark.table(fqn_tbl)
  log_msg(f"... rows in table? {result_df.count():,}", 1, debug_mode)
  if debug_mode > 0:
    result_df.limit(5).display()
  return result_df