## Rough Implementation of Data Cleaning and Joining Module - Notebook

### Importing Libraries

In [9]:
%matplotlib inline

from __future__ import annotations

# ────────────────────────────────────────────────────────────────────────────
# Data Manipulation & Analysis
# ─────────────────────────────────────────────────────────────────────────────
import pandas as pd
import polars as pl
import numpy as np

# ─────────────────────────────────────────────────────────────────────────────
# Geospatial Data Handling
# ─────────────────────────────────────────────────────────────────────────────
import xarray as xr
from xarray.coding.times import CFTimedeltaCoder
import cfgrib
# import xesmf as xe

# ─────────────────────────────────────────────────────────────────────────────
# Notebook/Display Tools
# ─────────────────────────────────────────────────────────────────────────────
from IPython.display import display
import matplotlib.pyplot as plt
import folium
from folium.plugins import MarkerCluster
from folium.features import RegularPolygonMarker

# ─────────────────────────────────────────────────────────────────────────────
# System / Miscellaneous
# ─────────────────────────────────────────────────────────────────────────────
import os
from pathlib import Path
import shutil
import logging
import re
import concurrent.futures

import math
from functools import reduce

from zoneinfo import ZoneInfo
from datetime import datetime, timedelta

from typing import List, Optional, Dict, Set, Tuple
from collections import defaultdict

import requests
from bs4 import BeautifulSoup

from eccodes import (
    codes_grib_new_from_file,
    codes_get,
    codes_release,
    CodesInternalError,
)

from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.chrome.options import Options
import time
import pandas as pd


### Initial Setup (directories)

In [10]:
cwd = os.getcwd()
print("-"*120)
print("Current Working Directory and contents:\n"+"-"*120)
for root, dirs, files in os.walk(cwd):
    print(f"\nDirectory: {root}")
    print(f"Subdirectories: {dirs}\n"+ "-"*40)
    for file in sorted(files):
        print(f"-> File: {file}")


------------------------------------------------------------------------------------------------------------------------
Current Working Directory and contents:
------------------------------------------------------------------------------------------------------------------------

Directory: /Users/Daniel/Desktop/dev-weather-data-processing
Subdirectories: ['.git']
----------------------------------------
-> File: README.md
-> File: data_cleaning_development.ipynb

Directory: /Users/Daniel/Desktop/dev-weather-data-processing/.git
Subdirectories: ['objects', 'info', 'logs', 'hooks', 'refs']
----------------------------------------
-> File: HEAD
-> File: config
-> File: description
-> File: index
-> File: packed-refs

Directory: /Users/Daniel/Desktop/dev-weather-data-processing/.git/objects
Subdirectories: ['pack', 'info']
----------------------------------------

Directory: /Users/Daniel/Desktop/dev-weather-data-processing/.git/objects/pack
Subdirectories: []
--------------------------

In [11]:
# DIRECTORIES AND PATHS
root_directory = Path.cwd().parent
print(root_directory)

data_directory = root_directory / "data"
print(data_directory)

data_raw_directory = data_directory / "raw"
print(data_raw_directory)

print("contents of raw data directory:")
for item in data_raw_directory.iterdir():
    print(item)

/Users/Daniel/Desktop
/Users/Daniel/Desktop/data
/Users/Daniel/Desktop/data/raw
contents of raw data directory:
/Users/Daniel/Desktop/data/raw/era5-world_N37W68S6E98_d514a3a3c256_2019_08.grib
/Users/Daniel/Desktop/data/raw/era5-world_N37W68S6E98_d514a3a3c256_2024_03.grib
/Users/Daniel/Desktop/data/raw/era5-world_N37W68S6E98_d514a3a3c256_2020_01.grib
/Users/Daniel/Desktop/data/raw/era5-world_N37W68S6E98_d514a3a3c256_2019_12.grib
/Users/Daniel/Desktop/data/raw/era5-world_N37W68S6E98_d514a3a3c256_2018_01.grib.5b7b6.idx
/Users/Daniel/Desktop/data/raw/.DS_Store
/Users/Daniel/Desktop/data/raw/era5-world_N37W68S6E98_d514a3a3c256_2025_04.grib
/Users/Daniel/Desktop/data/raw/era5-world_N37W68S6E98_d514a3a3c256_2018_02.grib
/Users/Daniel/Desktop/data/raw/era5-world_N37W68S6E98_d514a3a3c256_2024_04.grib
/Users/Daniel/Desktop/data/raw/era5-world_N37W68S6E98_d514a3a3c256_2025_06.grib
/Users/Daniel/Desktop/data/raw/era5-world_N37W68S6E98_d514a3a3c256_2018_01.grib
/Users/Daniel/Desktop/data/raw/era5-w

### Code

#### Part 1 - Identify and Select Files

##### Functions - Backend

In [12]:
def parse_filename(fname: str) -> Optional[tuple[str, int, int, str]]:
    """
    Parse ERA5-style filenames of the form:
        {prefix}_{YYYY}_{MM}[.ext]

    Assumes:
    - The filename contains exactly one dot before the extension.
    - The last two underscore-separated tokens before the dot are YYYY and MM.

    Parameters
    ----------
    fname: str
        Filename to parse.

    Returns
    -------
    tuple[str, int, int, str] | None
        prefix (str), year (int), month (int), extension (str) if parsing is successful, else None.
    """
    # Split into basename and extension
    root, extension = fname.split(".", 1)

    # Split root into prefix parts
    parts = root.split("_")

    # Need at least: prefix_part(s) + YYYY + MM
    if len(parts) < 3:
        return None

    year_str = parts[-2]
    month_str = parts[-1]

    # Validate numeric year/month
    if not (year_str.isdigit() and month_str.isdigit()):
        return None

    prefix = "_".join(parts[:-2])
    year = int(year_str)
    month = int(month_str)

    return prefix, year, month, extension

In [13]:
def scan_directory(raw_dir: Path) -> Dict[str, Dict[int, List[Path]]]:
    """
    Scan directory (assuming raw data) and build structured mapping using the format
    assumed by parse_filename() function (key, year, month, extension).

    Key format for mapping: tuple(prefix, extension)

    Sample Output structure:
    {
        ("prefix1", "grib"): {
             2018: [Path(...), Path(...)],
             2019: [...],
        },
        ("prefix2", "nc"): { ... }
    }

    Parameters
    ----------
    raw_dir: Path
        Path to the raw data directory to be scanned.

    Returns
    -------
    dict[str, dict[int, list[Path]]]
        Nested dictionary mapping dataset keys to years and lists of file paths.
        [str: dataset key, int: year, list[Path]: list of file paths]
    """
    mapping: Dict[str, Dict[int, List[Path]]] = defaultdict(lambda: defaultdict(list))

    for entry in raw_dir.iterdir():
        if entry.is_file() and not entry.name.startswith("."):
            parsed = parse_filename(entry.name)
            if parsed is None:
                continue

            prefix, year, month, extension = parsed
            dataset_key = f"{prefix} ({extension})"
            mapping[dataset_key][year].append(entry)

    return mapping

In [14]:
def print_dataset_summary(
        mapping: dict,
        directory: str
        ) -> None:
    """
    Print the dataset prefixes, years, and months in the raw directory.

    Parameters
    ----------
    mapping: dict
        Nested dictionary mapping dataset keys to years and lists of file paths.

    Returns
    -------
    None
        Prints the summary to console.
    """
    print(f"\nDetected the following datasets in [{directory}]:\n")

    for dataset_key, years in mapping.items():
        print("-"*40 + f"\n{dataset_key}\n" + "-"*40)
        print(f"\tYears found: {', '.join(str(y) for y in sorted(years))}")

        for year in sorted(years):
            months = []
            for path in years[year]:
                _, _, month, _ = parse_filename(path.name)
                months.append(month)

            mon_str = ", ".join(f"{m:02d}" for m in sorted(months))
            print(f"\t\t{year}: {mon_str}")
        print()

In [15]:
def scan_and_optionally_print(
        raw_dir: Path,
        print_summary: bool = True
        ) -> dict[str, dict[int, list[Path]]]:
    """
    Combined function to scan the directory and optionally print the results.

    Parameters
    ----------
    raw_dir : Path
        Directory containing raw ERA5 files
    print_summary : bool
        If True, prints the summary of datasets found

    Returns
    -------
    dict[str, dict[int, list[Path]]]
        Mapping of dataset_key → years → list of files
    """
    mapping = scan_directory(raw_dir)

    if print_summary:
        print_dataset_summary(mapping, directory=raw_dir)

    return mapping

##### Functions - Prompts

In [16]:
def prompt_user_for_directory(default_dir: str = "data/raw") -> Path:
    """
    Prompt the user to specify a directory relative to the project root.
    Validates that the directory exists. Falls back to default if desired.

    Parameters
    ----------
    default_dir : str
        The default directory to use if the user just presses ENTER.

    Returns
    -------
    Path
        A valid directory path pointing to the selected folder.
    """
    print("\nPlease enter the relative directory you want to scan.")
    print(f"Press ENTER to use the default: [{default_dir}]")
    print("Example inputs: data/raw, data/new_data, datasets/era5, etc.\n")

    while True:
        user_input = input("Directory: ").strip()

        # User pressed ENTER → use default
        if user_input == "":
            chosen = Path(default_dir)
        else:
            chosen = Path(user_input)

        if chosen.exists() and chosen.is_dir():
            print(f"\nUsing directory: {chosen.resolve()}\n")
            return chosen
        else:
            print(f"[ERROR] Directory '{chosen}' does not exist. Please try again.\n")


In [17]:
def prompt_user_for_dataset_key(mapping: dict) -> str:
    """
    Prompt the user to select one dataset key from the provided mapping.

    This function displays all available dataset identifiers (keys) generated
    by ``scan_directory`` or ``scan_and_optionally_print``. The user may select
    a dataset either by typing the full dataset key or by entering the
    corresponding number shown in the list.

    Parameters
    ----------
    mapping : dict
        A nested dictionary of the form::

            {
                "prefix_x (ext)": {
                    2018: [Path(...), Path(...), ...],
                    2019: [...],
                    ...
                },
                "prefix_y (ext)": { ... },
                ...
            }

        Keys correspond to dataset identifiers of the form
        ``"{prefix} ({extension})"``.

        Values are dictionaries mapping integer years to lists of ``Path`` objects
        referencing the monthly dataset files.

    Returns
    -------
    str
        The dataset key chosen by the user. This string will match exactly one
        of the keys present in ``mapping``.

    Notes
    -----
    - Input is validated to ensure that only valid selections are accepted.
    - The function loops until a correct response is provided.
    - This function is intended for CLI or notebook interactive workflows.
    - The returned key can be used to index into ``mapping`` to access
      the associated file paths.

    Examples
    --------
    >>> mapping = {
    ...     "era5-delhi (grib)": {2020: [Path("file1")], 2021: [Path("file2")]},
    ...     "era5-mumbai (grib)": {2019: [Path("file3")]}
    ... }
    >>> key = prompt_user_for_dataset_key(mapping)
    Available dataset prefixes:
      1. era5-delhi (grib)
      2. era5-mumbai (grib)
    Please enter the FULL dataset key above, or enter its number: 1
    You selected: era5-delhi (grib)
    >>> key
    'era5-delhi (grib)'
    """
    dataset_keys = sorted(mapping.keys())

    print("\nAvailable dataset prefixes:")
    for i, key in enumerate(dataset_keys, start=1):
        print(f"  {i}. {key}")

    while True:
        selection = input(
            "\nPlease enter the FULL dataset key above, or enter its number: "
        ).strip()

        # Case 1: user entered a number
        if selection.isdigit():
            idx = int(selection) - 1
            if 0 <= idx < len(dataset_keys):
                chosen = dataset_keys[idx]
                print(f"\nYou selected: {chosen}\n")
                return chosen
            else:
                print("[ERROR] Invalid number. Try again.")

        # Case 2: user entered full string dataset key
        elif selection in dataset_keys:
            print(f"\nYou selected: {selection}\n")
            return selection

        # Case 3: invalid input
        else:
            print("[ERROR] Invalid input. Please try again.")

##### Execution

In [18]:
# --- MAIN EXECUTION WORKFLOW ---

# 1. Ask the user which directory to scan
RAW_DIR = prompt_user_for_directory(default_dir="data/raw")

# 2. Scan the directory and optionally print a summary
mapping = scan_and_optionally_print(RAW_DIR, print_summary=True)

# 3. Prompt user to choose which dataset key to process
selected_key = prompt_user_for_dataset_key(mapping)

print(f"You will now process dataset: {selected_key}")



Please enter the relative directory you want to scan.
Press ENTER to use the default: [data/raw]
Example inputs: data/raw, data/new_data, datasets/era5, etc.

[ERROR] Directory 'data/raw' does not exist. Please try again.


Using directory: /Users/Daniel/Desktop/data/raw


Detected the following datasets in [../data/raw]:

----------------------------------------
era5-world_N37W68S6E98_d514a3a3c256 (grib)
----------------------------------------
	Years found: 2018, 2019, 2020, 2023, 2024, 2025
		2018: 01, 02
		2019: 08, 12
		2020: 01
		2023: 02, 12
		2024: 01, 03, 04
		2025: 04, 06

----------------------------------------
era5-world_N37W68S6E98_d514a3a3c256 (grib.5b7b6.idx)
----------------------------------------
	Years found: 2018
		2018: 01


Available dataset prefixes:
  1. era5-world_N37W68S6E98_d514a3a3c256 (grib)
  2. era5-world_N37W68S6E98_d514a3a3c256 (grib.5b7b6.idx)

You selected: era5-world_N37W68S6E98_d514a3a3c256 (grib)

You will now process dataset: era5-world_N37W68S6

#### Part 2 - Scanning File Contents & Processing Preparation

##### Functions - Backend

In [19]:
def get_first_file_for_dataset(files_by_year: Dict[int, List[Path]]) -> Path:
    """
    Return the first available file (by sorted year, then path) for a dataset.

    Parameters
    ----------
    files_by_year : dict[int, list[pathlib.Path]]
        Mapping of year → list of file paths.

    Returns
    -------
    pathlib.Path
        Path to the first file.

    Raises
    ------
    ValueError
        If no files are available.
    """
    if not files_by_year:
        raise ValueError("No files available for the dataset.")

    first_year = min(files_by_year.keys())
    if not files_by_year[first_year]:
        raise ValueError(f"No files found for year {first_year}.")

    return sorted(files_by_year[first_year])[0]


In [20]:
path = get_first_file_for_dataset(mapping[selected_key])
print(path)

../data/raw/era5-world_N37W68S6E98_d514a3a3c256_2018_01.grib


In [21]:
def scan_variables_in_file(path: Path) -> list[dict]:
    """
    Scan a GRIB file with ecCodes and return a list of dictionaries:
    [
        { "paramId": 167, "shortName": "2t" },
        { "paramId": 228, "shortName": "tp" },
        ...
    ]

    - Uses low-level eccodes API.
    - Works for ANY ERA5 world file.
    - Robust to mixed editions, hours, time steps.
    """
    variables = {}
    # Use dict keyed by paramId to dedupe

    try:
        with open(path, "rb") as f:
            while True:
                try:
                    gid = codes_grib_new_from_file(f)
                except CodesInternalError:
                    break
                if gid is None:
                    break

                try:
                    short_name = codes_get(gid, "shortName")
                    param_id = codes_get(gid, "paramId")

                    # decode if bytes
                    if isinstance(short_name, bytes):
                        short_name = short_name.decode("utf-8")

                    variables[param_id] = {
                        "paramId": param_id,
                        "shortName": short_name,
                    }

                finally:
                    codes_release(gid)

    except FileNotFoundError:
        print(f"File not found: {path}")
    except Exception as e:
        print(f"Unexpected error while scanning {path}: {e}")

    return list(variables.values())


In [22]:

def setup_driver():
    """Setup headless Chrome driver"""
    chrome_options = Options()
    chrome_options.add_argument("--headless")
    chrome_options.add_argument("--no-sandbox")
    chrome_options.add_argument("--disable-dev-shm-usage")
    driver = webdriver.Chrome(options=chrome_options)
    return driver

def get_parameter_details_selenium(param_id: int, driver) -> Dict:
    """
    Get parameter details using Selenium to render JavaScript
    """
    url = f"https://codes.ecmwf.int/grib/param-db/{param_id}/"

    try:
        driver.get(url)

        # Wait for the content to load
        WebDriverWait(driver, 10).until(
            EC.presence_of_element_located((By.XPATH, "//td[contains(., 'Name')]"))
        )

        details = {}

        # Extract Name
        try:
            name_cell = driver.find_element(By.XPATH, "//td[p[contains(., 'Name')]]/following-sibling::td")
            details['name'] = name_cell.find_element(By.TAG_NAME, "p").text
        except:
            details['name'] = ''

        # Extract Unit
        try:
            unit_cell = driver.find_element(By.XPATH, "//td[p[contains(., 'Unit')]]/following-sibling::td")
            details['unit'] = unit_cell.find_element(By.TAG_NAME, "p").text
        except:
            details['unit'] = ''

        # Extract Description
        try:
            desc_cell = driver.find_element(By.XPATH, "//td[p[contains(., 'Description')]]/following-sibling::td")
            details['description'] = desc_cell.find_element(By.TAG_NAME, "p").text
        except:
            details['description'] = ''

        return details

    except Exception as e:
        print(f"Error fetching param {param_id}: {e}")
        return {'name': '', 'unit': '', 'description': ''}

def enrich_parameters_selenium(parameter_list: List[Dict]) -> List[Dict]:
    """
    Enrich parameter list using Selenium
    """
    driver = setup_driver()
    enriched = []

    try:
        for i, param in enumerate(parameter_list):
            param_id = param['paramId']
            short_name = param['shortName']

            print(f"Processing {short_name} (ID: {param_id}) - {i+1}/{len(parameter_list)}")

            details = get_parameter_details_selenium(param_id, driver)

            # Create enriched parameter
            enriched_param = param.copy()
            enriched_param.update({
                'full_name': details.get('name', ''),
                'units': details.get('unit', ''),
                'description': details.get('description', '')
            })

            enriched.append(enriched_param)

            # Be polite to the server
            time.sleep(1)

    finally:
        driver.quit()

    return enriched

In [23]:
from typing import List, Dict
import time

def enrich_variable_list_selenium(vars_list: List[Dict]) -> List[Dict]:
    """
    Enrich a list of GRIB variable dicts using Selenium to extract metadata.

    Parameters
    ----------
    vars_list : List[Dict]
        List of dicts with 'paramId' and 'shortName'.

    Returns
    -------
    List[Dict]
        Enriched variable list with full name, units, and description.
    """
    driver = setup_driver()
    enriched = []

    try:
        for i, item in enumerate(vars_list):
            param_id = item["paramId"]
            short_name = item["shortName"]
            print(f"Processing {short_name} (ID: {param_id}) - {i+1}/{len(vars_list)}")

            try:
                details = get_parameter_details_selenium(param_id, driver)
            except Exception as e:
                print(f"❌ Failed to fetch details for paramId {param_id}: {e}")
                details = {'name': '', 'unit': '', 'description': ''}

            enriched_item = {
                "paramId": param_id,
                "shortName": short_name,
                "fullName": details.get("name", ''),
                "units": details.get("unit", ''),
                "description": details.get("description", ''),
                "url": f"https://codes.ecmwf.int/grib/param-db/{param_id}/"
            }
            enriched.append(enriched_item)

            # Respectful delay to avoid hammering ECMWF's servers
            time.sleep(1)

    finally:
        driver.quit()

    return enriched


In [24]:
vars = scan_variables_in_file(path)
print(vars)

[{'paramId': 167, 'shortName': '2t'}, {'paramId': 228, 'shortName': 'tp'}, {'paramId': 165, 'shortName': '10u'}, {'paramId': 166, 'shortName': '10v'}, {'paramId': 228246, 'shortName': '100u'}, {'paramId': 228247, 'shortName': '100v'}, {'paramId': 228022, 'shortName': 'cdir'}, {'paramId': 57, 'shortName': 'uvb'}, {'paramId': 176, 'shortName': 'ssr'}, {'paramId': 210, 'shortName': 'ssrc'}, {'paramId': 177, 'shortName': 'str'}, {'paramId': 211, 'shortName': 'strc'}, {'paramId': 228129, 'shortName': 'ssrdc'}, {'paramId': 169, 'shortName': 'ssrd'}, {'paramId': 228130, 'shortName': 'strdc'}, {'paramId': 175, 'shortName': 'strd'}, {'paramId': 178, 'shortName': 'tsr'}, {'paramId': 208, 'shortName': 'tsrc'}, {'paramId': 179, 'shortName': 'ttr'}, {'paramId': 209, 'shortName': 'ttrc'}, {'paramId': 228021, 'shortName': 'fdir'}, {'paramId': 188, 'shortName': 'hcc'}, {'paramId': 186, 'shortName': 'lcc'}, {'paramId': 187, 'shortName': 'mcc'}, {'paramId': 164, 'shortName': 'tcc'}, {'paramId': 28, 'sho

In [25]:
enriched = enrich_variable_list_selenium(vars)

print("\n✅ Enriched variable metadata:")
for v in enriched:
    print(v)

Processing 2t (ID: 167) - 1/30
Processing tp (ID: 228) - 2/30
Processing 10u (ID: 165) - 3/30
Processing 10v (ID: 166) - 4/30
Processing 100u (ID: 228246) - 5/30
Processing 100v (ID: 228247) - 6/30
Processing cdir (ID: 228022) - 7/30
Processing uvb (ID: 57) - 8/30
Processing ssr (ID: 176) - 9/30
Processing ssrc (ID: 210) - 10/30
Processing str (ID: 177) - 11/30
Processing strc (ID: 211) - 12/30
Processing ssrdc (ID: 228129) - 13/30
Processing ssrd (ID: 169) - 14/30
Processing strdc (ID: 228130) - 15/30
Processing strd (ID: 175) - 16/30
Processing tsr (ID: 178) - 17/30
Processing tsrc (ID: 208) - 18/30
Processing ttr (ID: 179) - 19/30
Processing ttrc (ID: 209) - 20/30
Processing fdir (ID: 228021) - 21/30
Processing hcc (ID: 188) - 22/30
Processing lcc (ID: 186) - 23/30
Processing mcc (ID: 187) - 24/30
Processing tcc (ID: 164) - 25/30
Processing cvh (ID: 28) - 26/30
Processing lai_hv (ID: 67) - 27/30
Processing lai_lv (ID: 66) - 28/30
Processing cvl (ID: 27) - 29/30
Processing kx (ID: 26

In [26]:
def get_variables_from_first_file(files_by_year: Dict[int, List[Path]]) -> Set[str]:
    """
    Use the first available file to detect ERA5 variable names.

    Parameters
    ----------
    files_by_year : dict[int, list[pathlib.Path]]
        Mapping of year → list of file paths.

    Returns
    -------
    set[str]
        Set of variable names found in the first file.
    """
    first_file = get_first_file_for_dataset(files_by_year)
    print(f"Using first file for variable scan: {first_file}")
    return scan_variables_in_file(first_file)


In [27]:
def print_variable_list(vars_found: Set[str]) -> None:
    """
    Pretty-print the list of detected variable names.

    Parameters
    ----------
    vars_found : set[str]
        Variable names to print.

    Returns
    -------
    None
    """
    print("\nDetected the following variables in the first file:\n")
    for v in sorted(vars_found):
        print(f"  • {v}")
    print(f"\nTotal variables: {len(vars_found)}\n")


In [28]:
def verify_variables_across_files(
        files_by_year: Dict[int, List[Path]],
        reference_vars: Set[str]
        ) -> None:
    """
    Optionally verify that all files share the same set of variable names.

    Parameters
    ----------
    files_by_year : dict[int, list[pathlib.Path]]
        Mapping of year → list of file paths.
    reference_vars : set[str]
        Reference set of variables from the first file.

    Returns
    -------
    None

    Notes
    -----
    Prints warnings if any file is found with a different variable set.
    """
    print("\nVerifying that all files share the same variable structure...\n")
    mismatches = []

    for year, paths in files_by_year.items():
        for path in paths:
            vars_this = scan_variables_in_file(path)
            if vars_this and vars_this != reference_vars:
                diff_ref = reference_vars - vars_this
                diff_this = vars_this - reference_vars
                mismatches.append((path, diff_ref, diff_this))

    if not mismatches:
        print("✅ All checked files share the same variable set.\n")
    else:
        print("⚠ Detected files with differing variable sets:\n")
        for path, missing_from_file, extra_in_file in mismatches:
            print(f"File: {path}")
            if missing_from_file:
                print("  Missing (present in reference but not in this file):")
                for v in sorted(missing_from_file):
                    print(f"    - {v}")
            if extra_in_file:
                print("  Extra (present in this file but not in reference):")
                for v in sorted(extra_in_file):
                    print(f"    - {v}")
            print()

        cont = prompt_yes_no("Differences detected. Continue processing anyway?",
                             default="n")
        if not cont:
            raise RuntimeError("Aborting due to variable set mismatches.")

In [29]:
def make_period_ids_for_year(
        year: int,
        agg_level: str
        ) -> List[Tuple[int, Optional[str]]]:
    """
    Generate (month, period_id) pairs for the given year and aggregation level.

    This is used to decide how to slice the data and how to name files.

    Parameters
    ----------
    year : int
        Year of interest.
    agg_level : {'annual', 'half-year', 'quarterly', 'monthly'}
        Aggregation level.

    Returns
    -------
    list of (int, str or None)
        Each entry is a tuple (month, period_id), where:
        - For 'annual':    one entry with (None, None)
        - For 'half-year': months are grouped into 'A1' (Jan–Jun), 'A2' (Jul–Dec)
        - For 'quarterly': months grouped into 'Q1'..'Q4'
        - For 'monthly':   one period per month with 'M01'..'M12'

    Notes
    -----
    For processing convenience we primarily use this to derive the period
    identifier (like 'A1', 'Q3', 'M05') and then filter rows by month ranges.
    """
    if agg_level == "annual":
        return [(None, None)]  # single period for whole year

    if agg_level == "monthly":
        return [(m, f"M{m:02d}") for m in range(1, 13)]

    if agg_level == "half-year":
        # A1: Jan–Jun, A2: Jul–Dec
        return [
            (1, "A1"),  # we use the month just as a handle here
            (7, "A2"),
        ]

    if agg_level == "quarterly":
        return [
            (1, "Q1"),   # Jan–Mar
            (4, "Q2"),   # Apr–Jun
            (7, "Q3"),   # Jul–Sep
            (10, "Q4"),  # Oct–Dec
        ]

    raise ValueError(f"Unknown aggregation level: {agg_level}")


In [30]:
def make_output_filename(
        prefix: str,
        year: int,
        agg_level: str,
        period_id: Optional[str]
        ) -> str:
    """
    Construct an output filename for a given prefix/year/aggregation/period.

    Examples
    --------
    Annual:
        prefix_annual_2020.parquet

    Half-year:
        prefix_halfyear_2020_A1.parquet

    Quarterly:
        prefix_quarterly_2020_Q3.parquet

    Monthly:
        prefix_monthly_2020_M07.parquet

    Parameters
    ----------
    prefix : str
        Dataset prefix (no extension and no spaces/parentheses).
    year : int
        Year of the data.
    agg_level : {'annual', 'half-year', 'quarterly', 'monthly'}
        Aggregation/split level.
    period_id : str or None
        Period identifier (e.g., 'A1', 'Q3', 'M07'); None for annual.

    Returns
    -------
    str
        Filename (without directory), always ending in '.parquet'.
    """
    if agg_level == "annual":
        return f"{prefix}_annual_{year}.parquet"

    if agg_level == "half-year":
        if period_id is None:
            raise ValueError("period_id must not be None for half-year.")
        return f"{prefix}_halfyear_{year}_{period_id}.parquet"

    if agg_level == "quarterly":
        if period_id is None:
            raise ValueError("period_id must not be None for quarterly.")
        return f"{prefix}_quarterly_{year}_{period_id}.parquet"

    if agg_level == "monthly":
        if period_id is None:
            raise ValueError("period_id must not be None for monthly.")
        return f"{prefix}_monthly_{year}_{period_id}.parquet"

    raise ValueError(f"Unknown aggregation level: {agg_level}")


##### Functions - Prompts

In [31]:
def prompt_yes_no(
        message: str,
        default: str = "y"
        ) -> bool:
    """
    Generic yes/no prompt.

    Parameters
    ----------
    message : str
        Question to display to the user.
    default : {'y', 'n'}, optional
        Default answer if the user just presses ENTER.

    Returns
    -------
    bool
        True if user answers yes, False otherwise.
    """
    default = default.lower()
    assert default in ("y", "n")

    suffix = "[Y/n]" if default == "y" else "[y/N]"

    while True:
        ans = input(f"{message} {suffix}: ").strip().lower()
        if ans == "" and default in ("y", "n"):
            return default == "y"
        if ans in ("y", "yes"):
            return True
        if ans in ("n", "no"):
            return False
        print("❌ Please answer 'y' or 'n'.")


In [32]:
def prompt_for_columns_to_drop(default_drop: Optional[List[str]] = None) -> List[str]:
    """
    Prompt the user to confirm or override the list of columns to drop.

    Parameters
    ----------
    default_drop : list of str or None, optional
        Default list of columns to drop. If None or empty, no columns are
        dropped unless the user specifies them manually.

    Returns
    -------
    list[str]
        Final list of columns to drop.
    """
    default_drop = default_drop or []

    if default_drop:
        print("\nThe following columns are currently marked for dropping:")
        for col in default_drop:
            print(f"  - {col}")
    else:
        print("\nNo default columns are marked for dropping.")

    user_input = input(
        "\nPress ENTER to accept this selection, or enter a comma-separated "
        "list of column names to drop instead: "
    ).strip()

    if user_input == "":
        final_drop = default_drop
    else:
        final_drop = [c.strip() for c in user_input.split(",") if c.strip()]

    print("\nFinal list of columns to drop:")
    if final_drop:
        for col in final_drop:
            print(f"  - {col}")
    else:
        print("  (None)")
    print()

    return final_drop

In [33]:
def prompt_for_aggregation_level() -> str:
    """
    Prompt the user to choose an aggregation/split level.

    Options
    -------
    - 'annual'
    - 'half-year'
    - 'quarterly'
    - 'monthly'

    Returns
    -------
    str
        One of {'annual', 'half-year', 'quarterly', 'monthly'}.
    """
    valid = {"annual", "half-year", "quarterly", "monthly"}

    print("\nPlease select the aggregation/split level for output files:")
    print("  - annual")
    print("  - half-year")
    print("  - quarterly")
    print("  - monthly")

    while True:
        choice = input("\nEnter aggregation level: ").strip().lower()
        if choice in valid:
            print(f"\nYou selected: {choice}\n")
            return choice
        print("❌ Invalid choice. Please enter one of: annual, half-year, quarterly, monthly.")


##### Execution

In [None]:
get_first_file_for_dataset(mapping[selected_key])
scan_variables_in_file(path)
get_variables_from_first_file(mapping[selected_key])
print_variable_list(get_variables_from_first_file(mapping[selected_key]))

<function __main__.print_variable_list(vars_found: 'Set[str]') -> 'None'>

#### Part 3 - Data Processing - Execution

##### Functions

In [35]:
def load_grib_file_to_polars(path: Path) -> pl.DataFrame:
    """
    Load a single GRIB file using xarray and convert it to a Polars DataFrame.

    Parameters
    ----------
    path : pathlib.Path
        Path to the GRIB file.

    Returns
    -------
    polars.DataFrame
        DataFrame containing all variables and coordinates.

    Notes
    -----
    - Uses xarray with cfgrib engine.
    - Assumes presence of a 'time' coordinate or a 'valid_time' coordinate.
    """
    ds = xr.open_dataset(
        path,
        engine="cfgrib",
        backend_kwargs={"indexpath": ""}
    )

    # Flatten to pandas then to polars
    pdf = ds.to_dataframe().reset_index()
    ds.close()

    pl_df = pl.from_pandas(pdf)

    # Normalise time column
    if "time" in pl_df.columns:
        time_col = "time"
    elif "valid_time" in pl_df.columns:
        time_col = "valid_time"
    else:
        raise ValueError(f"No 'time' or 'valid_time' column found in {path}.")

    pl_df = pl_df.with_columns(
        pl.col(time_col).alias("timestamp")
    ).drop(time_col)

    # Ensure timestamp is a proper datetime type
    pl_df = pl_df.with_columns(
        pl.col("timestamp").cast(pl.Datetime(time_unit="us"), strict=False)
    )

    return pl_df


In [36]:
def combine_year_files(
        files_by_year: Dict[int, List[Path]],
        year: int
        ) -> pl.LazyFrame:
    """
    Load and vertically concatenate all files for a given year into a LazyFrame.

    Parameters
    ----------
    files_by_year : dict[int, list[pathlib.Path]]
        Mapping of year → list of file paths.
    year : int
        Year to combine.

    Returns
    -------
    polars.LazyFrame
        Combined LazyFrame containing all rows for that year.
    """
    paths = files_by_year.get(year, [])
    if not paths:
        raise ValueError(f"No files for year {year}.")

    dfs: List[pl.DataFrame] = []
    for p in sorted(paths):
        print(f"   Loading {p} ...")
        df = load_grib_file_to_polars(p)
        dfs.append(df)

    combined = pl.concat(dfs, how="vertical")
    return combined.lazy()


In [37]:
def partition_and_save_year(
        lf_year: pl.LazyFrame,
        prefix: str,
        year: int,
        agg_level: str,
        drop_cols: List[str],
        output_dir: Path
        ) -> None:
    """
    Partition a year's LazyFrame by the chosen aggregation level and save
    each partition to a separate parquet file.

    Parameters
    ----------
    lf_year : polars.LazyFrame
        LazyFrame containing all data for a given year.
    prefix : str
        Dataset prefix (no extension, no parentheses).
    year : int
        Year of the data.
    agg_level : {'annual', 'half-year', 'quarterly', 'monthly'}
        Aggregation/split level.
    drop_cols : list[str]
        Columns to drop before writing.
    output_dir : pathlib.Path
        Output directory for parquet files.

    Returns
    -------
    None
    """
    # Add year and month derived from timestamp
    lf = lf_year.with_columns([
        pl.col("timestamp").dt.year().alias("year"),
        pl.col("timestamp").dt.month().alias("month"),
    ])

    # Drop columns if requested
    if drop_cols:
        # Filter drop_cols to those that actually exist
        cols_in_schema = set(lf.collect_schema().names())
        actual_drop = [c for c in drop_cols if c in cols_in_schema]
        if actual_drop:
            lf = lf.drop(actual_drop)
        else:
            print("Warning: none of the requested drop columns exist in this schema.")

    output_dir.mkdir(parents=True, exist_ok=True)

    period_specs = make_period_ids_for_year(year, agg_level)

    if agg_level == "annual":
        # One file for the whole year
        fname = make_output_filename(prefix, year, agg_level, period_id=None)
        out_path = output_dir / fname
        print(f"   Writing annual file: {out_path}")
        lf.filter(pl.col("year") == year).sink_parquet(out_path)
        return

    for month_anchor, period_id in period_specs:
        if agg_level == "monthly":
            # Filter rows for this month
            m = month_anchor
            fname = make_output_filename(prefix, year, agg_level, period_id)
            out_path = output_dir / fname
            print(f"   Writing monthly file: {out_path}")

            lf.filter((pl.col("year") == year) & (pl.col("month") == m)) \
              .sink_parquet(out_path)

        elif agg_level == "half-year":
            # A1: Jan–Jun, A2: Jul–Dec
            if period_id == "A1":
                cond = (pl.col("month") >= 1) & (pl.col("month") <= 6)
            elif period_id == "A2":
                cond = (pl.col("month") >= 7) & (pl.col("month") <= 12)
            else:
                raise ValueError(f"Unknown half-year period_id: {period_id}")

            fname = make_output_filename(prefix, year, agg_level, period_id)
            out_path = output_dir / fname
            print(f"   Writing half-year file: {out_path}")
            lf.filter((pl.col("year") == year) & cond).sink_parquet(out_path)

        elif agg_level == "quarterly":
            # Q1: Jan–Mar, Q2: Apr–Jun, Q3: Jul–Sep, Q4: Oct–Dec
            if period_id == "Q1":
                cond = (pl.col("month") >= 1) & (pl.col("month") <= 3)
            elif period_id == "Q2":
                cond = (pl.col("month") >= 4) & (pl.col("month") <= 6)
            elif period_id == "Q3":
                cond = (pl.col("month") >= 7) & (pl.col("month") <= 9)
            elif period_id == "Q4":
                cond = (pl.col("month") >= 10) & (pl.col("month") <= 12)
            else:
                raise ValueError(f"Unknown quarterly period_id: {period_id}")

            fname = make_output_filename(prefix, year, agg_level, period_id)
            out_path = output_dir / fname
            print(f"   Writing quarterly file: {out_path}")
            lf.filter((pl.col("year") == year) & cond).sink_parquet(out_path)

        else:
            raise ValueError(f"Unsupported aggregation level: {agg_level}")


In [38]:
def process_year(
        dataset_key: str,
        files_by_year: Dict[int, List[Path]],
        year: int,
        agg_level: str,
        drop_cols: List[str],
        output_dir: Path
        ) -> None:
    """
    Process all files for a single year: load, combine, partition, and save.

    Parameters
    ----------
    dataset_key : str
        Dataset key, e.g. "era5-world_N37... (grib)".
    files_by_year : dict[int, list[pathlib.Path]]
        Mapping of year → list of file paths.
    year : int
        Year to process.
    agg_level : str
        Aggregation/split level: 'annual', 'half-year', 'quarterly', 'monthly'.
    drop_cols : list[str]
        Columns to drop.
    output_dir : pathlib.Path
        Output directory for parquet files.

    Returns
    -------
    None
    """
    prefix = dataset_key.split(" (")[0]  # strip " (ext)"
    print(f"\n=== Processing year {year} for dataset '{dataset_key}' ===")
    lf_year = combine_year_files(files_by_year, year)
    partition_and_save_year(lf_year, prefix, year, agg_level, drop_cols, output_dir)
    print(f"=== Finished year {year} ===\n")


In [39]:
def process_dataset(
        dataset_key: str,
        files_by_year: Dict[int, List[Path]],
        agg_level: str,
        drop_cols: List[str],
        output_dir: Path,
        max_workers: int = 4
        ) -> None:
    """
    Process an entire dataset across all available years with optional
    parallelisation (one worker per year).

    Parameters
    ----------
    dataset_key : str
        Dataset key, e.g. "era5-world_N37... (grib)".
    files_by_year : dict[int, list[pathlib.Path]]
        Mapping of year → list of file paths for the selected dataset.
    agg_level : str
        Aggregation/split level: 'annual', 'half-year', 'quarterly', 'monthly'.
    drop_cols : list[str]
        Columns to drop.
    output_dir : pathlib.Path
        Output directory for parquet files.
    max_workers : int, optional
        Maximum number of worker threads to use for parallel processing of years.

    Returns
    -------
    None
    """
    years = sorted(files_by_year.keys())
    if not years:
        print("No years found to process.")
        return

    print(f"\nStarting processing for dataset '{dataset_key}' with aggregation: {agg_level}")
    print(f"Years to process: {years}")
    print(f"Output directory: {output_dir.resolve()}\n")

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(
                process_year,
                dataset_key,
                files_by_year,
                year,
                agg_level,
                drop_cols,
                output_dir,
            )
            for year in years
        ]
        for fut in concurrent.futures.as_completed(futures):
            exc = fut.exception()
            if exc is not None:
                print(f"Error in worker: {exc}")


#### Main Execution

In [40]:
def main_interactive():
    """
    Full interactive workflow:

    1. Prompt for directory.
    2. Scan directory, show dataset summary.
    3. Prompt for dataset key.
    4. Scan variables from first file.
    5. Optionally verify structure across all files.
    6. Prompt for columns to drop (default: none).
    7. Prompt for aggregation level.
    8. Process dataset into time-split parquet files.
    """
    # 1. Directory selection
    raw_dir = prompt_user_for_directory(default_dir="data/raw")

    # 2. Scan & summary
    mapping = scan_and_optionally_print(raw_dir, print_summary=True)

    # 3. Dataset key selection
    selected_key = prompt_user_for_dataset_key(mapping)
    files_by_year = mapping[selected_key]

    # 4. Variable scan from first file
    vars_first = get_variables_from_first_file(files_by_year)
    print_variable_list(vars_first)

    # 5. Optional structure verification
    if prompt_yes_no("Would you like to verify that all files share the same structure?",
                     default="y"):
        verify_variables_across_files(files_by_year, vars_first)
    else:
        print("Skipping structure verification.\n")

    # 6. Column drop prompt (no defaults for now)
    drop_cols = prompt_for_columns_to_drop(default_drop=[])

    # 7. Aggregation level prompt
    agg_level = prompt_for_aggregation_level()

    # 8. Process dataset into parquet
    output_dir = Path("data/processed")
    # Simple heuristic: one worker per year, up to 8
    max_workers = min(len(files_by_year), 8) or 1

    process_dataset(
        dataset_key=selected_key,
        files_by_year=files_by_year,
        agg_level=agg_level,
        drop_cols=drop_cols,
        output_dir=output_dir,
        max_workers=max_workers,
    )

    print("\n✅ All done.\n")


#### Exec

In [41]:
main_interactive()


Please enter the relative directory you want to scan.
Press ENTER to use the default: [data/raw]
Example inputs: data/raw, data/new_data, datasets/era5, etc.

[ERROR] Directory 'data/raw' does not exist. Please try again.


Using directory: /Users/Daniel/Desktop/data/raw


Detected the following datasets in [../data/raw]:

----------------------------------------
era5-world_N37W68S6E98_d514a3a3c256 (grib)
----------------------------------------
	Years found: 2018, 2019, 2020, 2023, 2024, 2025
		2018: 01, 02
		2019: 08, 12
		2020: 01
		2023: 02, 12
		2024: 01, 03, 04
		2025: 04, 06

----------------------------------------
era5-world_N37W68S6E98_d514a3a3c256 (grib.5b7b6.idx)
----------------------------------------
	Years found: 2018
		2018: 01


Available dataset prefixes:
  1. era5-world_N37W68S6E98_d514a3a3c256 (grib)
  2. era5-world_N37W68S6E98_d514a3a3c256 (grib.5b7b6.idx)

You selected: era5-world_N37W68S6E98_d514a3a3c256 (grib)

Using first file for variable scan: ../data/raw/e

TypeError: '<' not supported between instances of 'dict' and 'dict'

#### Parsing ERA5 Filenames


In [None]:
import xarray as xr
from pathlib import Path

def scan_grib_variables_xarray(files: dict[int, list[Path]]) -> set[str]:
    """
    Scan GRIB files using xarray+cfgrib to extract unique variable shortNames.
    Does not load any full data arrays.
    """
    shortnames = set()

    for year, paths in files.items():
        for path in paths:
            try:
                # Open the GRIB file *without* selecting variables
                # Will fail intentionally, giving us available citations
                xr.open_dataset(
                    path,
                    engine="cfgrib",
                    backend_kwargs={"indexpath": ""}
                )
            except Exception as e:
                # cfgrib includes the available 'shortName' values in the exception message
                msg = str(e)
                if "shortName" in msg and "values" in msg:
                    # Example:
                    # "Found multiple values for key 'shortName': ['t2m', 'u10', 'v10']"
                    import re
                    found = re.findall(r"\['([^]]+)'\]", msg)
                    if found:
                        # Split comma-separated values inside the brackets
                        vars_split = found[0].replace("'", "").split(",")
                        vars_clean = [v.strip() for v in vars_split]
                        shortnames.update(vars_clean)

    return shortnames
