In [None]:
import datetime
import json
import time
from enum import Enum
from typing import Tuple

import fsspec
import matplotlib.pyplot as plt
import numpy as np
import omfiles
from s3fs import S3FileSystem


class SupportedDomain(Enum):
    """Weather-model domains this file knows how to address.

    Each member's value is the domain name as it appears in the storage
    paths built from ``domain.value``.
    """

    dwd_icon_d2 = "dwd_icon_d2"
    ecmwf_ifs025 = "ecmwf_ifs025"

    def file_length(self):
        """Return how many timesteps a single chunk file of this domain holds.

        Raises:
            ValueError: If no chunk length is defined for this member.
        """
        if self is SupportedDomain.dwd_icon_d2:
            return 121
        if self is SupportedDomain.ecmwf_ifs025:
            return 104
        raise ValueError(f"Unsupported domain {self}")

    def lat_lon_grid(self) -> Tuple[np.ndarray, np.ndarray]:
        """Build the 1-D latitude and longitude coordinate axes of this domain.

        Returns:
            A ``(lat, lon)`` pair of 1-D arrays, both in ascending order.

        Raises:
            ValueError: If no grid is defined for this member.
        """
        if self is SupportedDomain.ecmwf_ifs025:
            # Regular global grid: 721 latitudes from -90 to 90 (both ends
            # included) and 1440 longitudes starting at -180 (180 excluded).
            return (
                np.linspace(-90, 90, 721, endpoint=True),
                np.linspace(-180, 180, 1440, endpoint=False),
            )

        if self is SupportedDomain.dwd_icon_d2:
            # Regular grid with 0.02 spacing: 1215 longitude points by
            # 746 latitude points, anchored at its south-west corner.
            south, west = 43.18, -3.94
            spacing = 0.02
            n_lat, n_lon = 746, 1215
            # The upper bounds are exclusive (endpoint=False), so the axes
            # advance by exactly `spacing` per index.
            north = south + spacing * n_lat
            east = west + spacing * n_lon
            return (
                np.linspace(south, north, n_lat, endpoint=False),
                np.linspace(west, east, n_lon, endpoint=False),
            )

        raise ValueError(f"Unsupported domain {self}")


class SupportedVariable(Enum):
    """Weather variables that can be requested.

    The member value is the variable name inserted into the storage path
    via ``variable.value``.
    """

    # Currently the only supported variable.
    temperature_2m = "temperature_2m"


def find_chunk_for_timestamp(target_time: datetime.datetime, domain: SupportedDomain) -> Tuple[int, np.ndarray]:
    """
    Find the chunk number that contains a specific timestamp.

    The domain's ``meta.json`` is read anonymously from S3 to obtain the
    temporal resolution. Chunks are consecutive, fixed-length windows counted
    from the Unix epoch, each spanning ``domain.file_length()`` timesteps.
    A one-line summary of the chunk's time range is printed as a side effect.

    Args:
        target_time: The timestamp to find. Naive values are interpreted as
            UTC; timezone-aware values are converted to UTC first.
        domain: The domain to search in

    Returns:
        Tuple containing the chunk number and a ``datetime64[s]`` array with
        one entry per timestep of the chunk (start inclusive, end exclusive)

    Raises:
        KeyError: If the metadata has no ``temporal_resolution_seconds`` entry.
        ValueError: If ``domain`` has no known chunk length.
    """
    meta_file = f"openmeteo/data/{domain.value}/static/meta.json"
    # Load metadata from S3
    fs = fsspec.filesystem(protocol="s3", anon=True)
    with fs.open(meta_file, mode="r") as f:
        metadata = json.load(f)

    # Get domain-specific parameters
    dt_seconds = metadata["temporal_resolution_seconds"]
    om_file_length = domain.file_length()
    chunk_seconds = om_file_length * dt_seconds

    # The epoch below is naive, and Python refuses to subtract naive and aware
    # datetimes; normalise aware inputs to naive UTC so both kinds work.
    if target_time.utcoffset() is not None:
        target_time = target_time.astimezone(datetime.timezone.utc).replace(tzinfo=None)

    # Calculate seconds since epoch for the target time
    epoch = datetime.datetime(1970, 1, 1)
    target_seconds = int((target_time - epoch).total_seconds())

    # Calculate the chunk number
    chunk = target_seconds // chunk_seconds

    # Calculate the timerange for the chunk
    chunk_start = np.datetime64(epoch + datetime.timedelta(seconds=chunk * chunk_seconds))
    chunk_end = np.datetime64(epoch + datetime.timedelta(seconds=(chunk + 1) * chunk_seconds))
    print(f"Chunk {chunk} covers the timerange from {chunk_start} to {chunk_end}")
    dt_range = np.arange(
        chunk_start, chunk_end, np.timedelta64(datetime.timedelta(seconds=dt_seconds)), dtype="datetime64[s]"
    )

    return chunk, dt_range

In [None]:
async def main():
    """Fetch one 2 m temperature tile from the public bucket and plot it.

    Steps: locate the chunk file containing the configured timestamp, read a
    fixed lat/lon window at the first timestep of that chunk through the
    asynchronous reader, report how long the fetch took, and draw the tile as
    a filled contour plot.
    """
    # Setup parameters
    domain = SupportedDomain.dwd_icon_d2
    variable = SupportedVariable.temperature_2m
    timestamp = datetime.datetime(2025, 2, 1, 12, 0)

    # Get domain information
    chunk_num, timerange = find_chunk_for_timestamp(timestamp, domain)
    lat, lon = domain.lat_lon_grid()
    s3_file = f"openmeteo/data/{domain.value}/{variable.value}/chunk_{chunk_num}.om"

    print(f"Accessing file: {s3_file}")
    print(f"Grid dimensions: lat={len(lat)}, lon={len(lon)}")

    # Define a region of interest (subset of data to extract)
    start_lat_idx, end_lat_idx = 200, 300  # Example slice
    start_lon_idx, end_lon_idx = 400, 500  # Example slice
    time_idx = 0  # First timestamp in the chunk

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()

    # Anonymous S3 access with fsspec's read cache disabled (cache type "none")
    fs_async = S3FileSystem(anon=True, default_block_size=256, default_cache_type="none")

    # Initialize the concurrent reader
    reader_async = await omfiles.OmFilePyReaderAsync.from_fsspec(fs_async, s3_file)
    print(f"Data shape: {reader_async.shape}")

    # Extract data slice concurrently
    data = await reader_async.read_concurrent(
        (slice(start_lat_idx, end_lat_idx), slice(start_lon_idx, end_lon_idx), slice(time_idx, time_idx + 1))
    )

    elapsed_time = time.perf_counter() - start_time
    print(f"Data fetching time: {elapsed_time:.2f} seconds")

    # contourf needs a 2-D field. The request keeps a length-1 time axis, so
    # drop it if the reader returned it.
    # NOTE(review): a reader that already squeezes singleton axes is left
    # untouched by this guard — confirm which behaviour the installed
    # omfiles version has.
    if np.ndim(data) == 3:
        data = np.asarray(data)[:, :, 0]

    # Extract the actual lat/lon values for this region
    region_lat = lat[start_lat_idx:end_lat_idx]
    region_lon = lon[start_lon_idx:end_lon_idx]

    # Create meshgrid for plotting
    lon_grid, lat_grid = np.meshgrid(region_lon, region_lat)

    # Visualize the data
    plt.figure(figsize=(10, 6))
    plt.contourf(lon_grid, lat_grid, data, cmap="RdBu_r", levels=20)
    plt.colorbar(label="Temperature (°C)")
    plt.title(f"Temperature at 2m - {timerange[time_idx]}")
    plt.xlabel("Longitude (°)")
    plt.ylabel("Latitude (°)")
    plt.grid(linestyle="--", alpha=0.5)
    plt.show()


# Run the async function.
# A bare top-level ``await`` only works where the interpreter supports it
# (e.g. a notebook/IPython cell with a running event loop). In a plain script,
# drive the coroutine with ``asyncio.run(main())`` instead.
await main()