In [2]:
import os
import time
import glob
import pickle
from concurrent.futures import ProcessPoolExecutor, as_completed


import pandas as pd
import xarray as xr
import numpy as np


def _convert_forecast_to_historical_units(forecast_df: pd.DataFrame):
    """Converts forecast data units to match historical data units."""
    forecast_df[["tas", "tasmax", "tasmin"]] -= 273.15  # Convert Kelvin to Celsius
    forecast_df[["rsds"]] *= 24  # Convert daily mean to total daily radiation
    return forecast_df


def _extract_point_data(
    dataset: xr.Dataset, lat: float, lon: float, variables: list[str]
):
    """Returns the time series of ``variables`` at the grid cell nearest lat/lon.

    The selection uses nearest-neighbour matching on the ``lat`` and ``lon``
    coordinates. The ``time`` coordinate is renamed to ``iso-date``, reduced to
    plain ``datetime.date`` values and used as the index of the result.

    Args:
        dataset: Dataset containing the requested variables.
        lat: Latitude of the point of interest.
        lon: Longitude of the point of interest.
        variables: Names of the data variables to extract.

    Returns:
        DataFrame indexed by ``iso-date`` with one column per requested variable.
    """
    nearest_cell = dataset[variables].sel(lat=lat, lon=lon, method="nearest")
    frame = nearest_cell.to_dataframe().reset_index()
    frame = frame.rename(columns={"time": "iso-date"})
    # Keep only the calendar date so both sources share the same index labels.
    frame["iso-date"] = pd.to_datetime(frame["iso-date"]).dt.date
    frame = frame.set_index("iso-date")
    return frame[variables]


def _combine_dataframes(historical_df: pd.DataFrame, forecast_df: pd.DataFrame):
    """Combines historical and forecast dataframes, removing overlaps."""
    # Find last valid index in historical data
    last_valid_historical = historical_df.isnull().all(axis=1).idxmax() - pd.Timedelta(
        days=1
    )
    historical_df = historical_df.loc[:last_valid_historical]

    combined_df = pd.concat([historical_df, forecast_df], axis=0)
    return combined_df[~combined_df.index.duplicated(keep="first")]


def _load_netcdf_data(folder: str, file_pattern: str, start_year: int, end_year: int):
    """Loads NetCDF files within the year range into an xarray dataset."""
    file_paths = []
    for year in range(start_year, end_year + 1):
        file_paths.extend(
            glob.glob(os.path.join(folder, file_pattern.format(year=year)))
        )

    if not file_paths:
        print(f"No matching files found for years {start_year} to {end_year}.")
        return None
    return xr.open_mfdataset(file_paths)

def _process_point_batch(historical_ds, forecast_ds, variables, lat_batch, lon_batch):
    """Processes a batch of points and times the execution."""
    batch_start_time = time.time()
    batch_results = {}
    for lat, lon in zip(lat_batch, lon_batch):
        historical_df = _extract_point_data(historical_ds, lat, lon, variables)
        forecast_df = _extract_point_data(forecast_ds, lat, lon, variables)
        forecast_df = _convert_forecast_to_historical_units(forecast_df)
        batch_results[f"{lat},{lon}"] = _combine_dataframes(historical_df, forecast_df)
    batch_end_time = time.time()
    print(f"Batch processed in {batch_end_time - batch_start_time:.2f} seconds")
    return batch_results


def extract_variables_over_years(
    netcdf_folder: str,
    variables: list[str],
    file_name_pattern: str,
    start_year: int,
    end_year: int,
    mask_path: str,
    batch_size: int = None,
    forecast_folder: str = "netcdf_files/forecasts_2024_05/r1i1p1",
    forecast_file_pattern: str = "combined.nc",
    forecast_start_year: int = 2024,
    forecast_end_year: int = 2024,
):
    """Extracts and combines data for valid grid points, processing in batches.

    The mask selects the grid cells to process. For every selected cell the
    historical and forecast series are extracted and merged; the work is split
    into batches that are submitted to a ``ProcessPoolExecutor``.

    Args:
        netcdf_folder: Folder containing the historical NetCDF files.
        variables: Names of the variables to extract.
        file_name_pattern: Historical file name pattern, formatted with
            ``year=<year>``.
        start_year: First historical year to load.
        end_year: Last historical year to load (inclusive).
        mask_path: Path to a pickled 2-D mask; non-zero entries mark the
            ``(lat index, lon index)`` cells to process.
        batch_size: Number of points per submitted batch. When ``None`` it is
            the number of valid points divided by the CPU count (at least 1).
        forecast_folder: Folder containing the forecast NetCDF file(s).
        forecast_file_pattern: Forecast file name pattern, formatted with
            ``year=<year>``.
        forecast_start_year: First forecast year to load.
        forecast_end_year: Last forecast year to load (inclusive).

    Returns:
        Dict mapping ``"<lat>,<lon>"`` to the combined DataFrame of that point.
        An empty dict is returned when either dataset could not be loaded.

    Raises:
        ValueError: If ``batch_size`` is given but smaller than 1, or if the
            mask is not two-dimensional.
    """
    if batch_size is not None and batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")

    # SECURITY: pickle.load can execute arbitrary code while unpickling; only
    # point mask_path at files from a trusted source.
    with open(mask_path, "rb") as f:
        mask = np.array(pickle.load(f))

    if mask.ndim != 2:
        raise ValueError(f"Expected a 2-D mask, got an array with shape {mask.shape}.")

    valid_lat_indices, valid_lon_indices = np.where(mask)

    # Calculate batch size if not provided
    if batch_size is None:
        num_processes = os.cpu_count() or 1  # Get CPU count (default to 1)
        batch_size = max(1, len(valid_lat_indices) // num_processes)

    historical_ds = None
    forecast_ds = None
    try:
        historical_ds = _load_netcdf_data(
            netcdf_folder, file_name_pattern, start_year, end_year
        )
        forecast_ds = _load_netcdf_data(
            forecast_folder,
            forecast_file_pattern,
            forecast_start_year,
            forecast_end_year,
        )

        # Truthiness check kept on purpose: it covers both a failed load
        # (None) and a dataset that evaluates as false.
        if not historical_ds or not forecast_ds:
            return {}

        # Cell (i, j) of the mask corresponds to lats[i] / lons[j], so the
        # coordinates can be looked up directly without building a full grid.
        lats = np.asarray(historical_ds["lat"].values).ravel()
        lons = np.asarray(historical_ds["lon"].values).ravel()

        df_dict = {}
        start_time = time.time()

        with ProcessPoolExecutor() as executor:
            # Submit all batches as futures
            futures = []
            for i in range(0, len(valid_lat_indices), batch_size):
                batch = slice(i, i + batch_size)
                futures.append(
                    executor.submit(
                        _process_point_batch,
                        historical_ds,
                        forecast_ds,
                        variables,
                        lats[valid_lat_indices[batch]],
                        lons[valid_lon_indices[batch]],
                    )
                )

            # Collect results from completed futures
            for future in as_completed(futures):
                df_dict.update(future.result())

        print(
            f"Processed {len(df_dict)} points in {time.time() - start_time:.2f} seconds."
        )
        return df_dict
    finally:
        # Close whatever was opened, including on early return or when a
        # worker raised.
        for opened in (historical_ds, forecast_ds):
            if opened is not None:
                opened.close()

In [3]:
# Build one combined (historical + forecast) DataFrame per unmasked grid point,
# loading the historical files for 2023 and 2024.
df_dict = extract_variables_over_years(
    netcdf_folder="netcdf_files/combined",
    variables=["hurs", "pr", "rsds", "sfcWind", "tas", "tasmax", "tasmin"],
    file_name_pattern="zalf_combined_amber_{year}_v1-0_uncompressed.nc",
    start_year=2023,
    end_year=2024,
    mask_path="data_availability_mask_h_and_fc.pkl",
)
# Number of grid points for which a frame was produced.
print(len(df_dict))

In [3]:
# Prints the whole mapping of "lat,lon" keys to DataFrames. pandas abbreviates
# each frame, but one is printed per grid point, so the output can be long.
print(df_dict)

{'54.92615337468216,8.356560653521498':                  hurs         pr         rsds    sfcWind        tas  \
iso-date                                                              
2023-01-01  93.000000   0.500000   443.799988  11.000000   8.900001   
2023-01-02  94.200005   6.100000   268.000000   5.700000   7.000000   
2023-01-03  84.900002   3.200000   696.599976   7.000000   5.600000   
2023-01-04  95.900002  15.800000   110.000000  12.200000   8.000000   
2023-01-05  86.800003   0.800000   418.000031   7.800000   5.800000   
...               ...        ...          ...        ...        ...   
2024-10-27  95.700569   0.000000  1096.704712   3.500567  10.366455   
2024-10-28  94.465347   0.000142   966.409058   7.234089  12.899994   
2024-10-29  90.264778   0.000001  1033.813232   6.000567  11.633545   
2024-10-30  90.400002   0.000000   564.981506   4.733522   8.666473   
2024-10-31  93.766479   0.000000   340.063538   6.200567   9.633545   

               tasmax     tasmin  
i