In [None]:
import logging
import os
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from io import BytesIO, StringIO
from logging import info, warning
from tempfile import NamedTemporaryFile
from typing import Dict, Optional

import boto3
import xarray as xr
from botocore import UNSIGNED
from botocore.config import Config

In [None]:
# Grab the root logger (getLogger() with no name) and lower its threshold to
# INFO. The helpers in this notebook log through the module-level
# logging.info / logging.warning functions, which write to the root logger.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [None]:
@dataclass
class ModelConfig:
    """Describes where one weather model's files live on S3.

    All ``*_pattern`` fields are ``str.format`` templates. The prefix pattern
    is formatted with ``run`` only; the file-name patterns are formatted with
    both ``run`` and ``forecast``.
    """

    # Short identifier for the model.
    short_name: str
    # Number of hours between consecutive model runs.
    model_hour_interval: int
    # Name of the S3 bucket holding the model output.
    s3_bucket: str
    # Template for the key prefix of a run; receives ``run``.
    s3_prefix_pattern: str
    # Template for a grib file name; receives ``run`` and ``forecast``.
    s3_grib_file_name_pattern: str
    # Template for an index file name; receives ``run`` and ``forecast``.
    s3_idx_file_name_pattern: str
    # Object count that get_latest_run compares against to call a run complete.
    s3_complete_run_number_of_files: int
    # Free-form parameter mapping (not read by any method of this class).
    parameters: Dict[str, str]

    def get_s3_prefix(self, run: datetime) -> str:
        """Return the S3 key prefix for ``run``."""
        return self.s3_prefix_pattern.format(run=run)

    def get_grib_file_name(self, run: datetime, forecast: int) -> str:
        """Return the bare grib file name for ``run`` and ``forecast``."""
        return self.s3_grib_file_name_pattern.format(run=run, forecast=forecast)

    def get_idx_file_name(self, run: datetime, forecast: int) -> str:
        """Return the bare index file name for ``run`` and ``forecast``."""
        return self.s3_idx_file_name_pattern.format(run=run, forecast=forecast)

    def get_s3_grib_key(self, run: datetime, forecast: int) -> str:
        """Return the full S3 key (prefix + file name) of a grib file."""
        prefix = self.get_s3_prefix(run=run)
        file_name = self.get_grib_file_name(run=run, forecast=forecast)
        return prefix + file_name

    def get_s3_idx_key(self, run: datetime, forecast: int) -> str:
        """Return the full S3 key (prefix + file name) of an index file."""
        prefix = self.get_s3_prefix(run=run)
        file_name = self.get_idx_file_name(run=run, forecast=forecast)
        return prefix + file_name


# Configuration for the GFS model as published in the "noaa-gfs-bdp-pds" bucket.
gfs_config = ModelConfig(
    # Identity and run cadence.
    short_name="gfs",
    model_hour_interval=6,
    # S3 location: bucket, per-run prefix, and per-forecast file names.
    s3_bucket="noaa-gfs-bdp-pds",
    s3_prefix_pattern="gfs.{run:%Y%m%d}/{run:%H}/atmos/",
    s3_idx_file_name_pattern="gfs.t{run:%H}z.pgrb2.0p25.f{forecast:03d}.idx",
    s3_grib_file_name_pattern="gfs.t{run:%H}z.pgrb2.0p25.f{forecast:03d}",
    # Object count used by get_latest_run to decide a run is complete.
    # NOTE(review): presumably grib files plus their .idx companions - confirm.
    s3_complete_run_number_of_files=418,
    parameters={},
)

In [None]:
def get_latest_possible_model_run(
    model_hour_interval: int, now: Optional[datetime] = None
) -> datetime:
    """Gets the latest possible model run time based on the current
    utc time.

    The current hour is floored to a multiple of ``model_hour_interval``. If
    that cycle started 3 hours ago or less, the cycle before it is returned
    instead.

    Args:
        model_hour_interval (int): The interval between model runs in hours.
        now (Optional[datetime]): Reference time to compute from, expected to
            be expressed in UTC. Defaults to the current UTC time as a naive
            datetime.
    Returns:
        datetime: Datetime object representing the latest possible model run,
            truncated to the top of the hour.
    Raises:
        ValueError: If ``model_hour_interval`` is not a positive integer.
    """
    if model_hour_interval <= 0:
        raise ValueError(
            f"model_hour_interval must be positive, got {model_hour_interval}"
        )

    if now is None:
        # Same naive-UTC value datetime.utcnow() produced, without relying on
        # the deprecated API.
        now = datetime.now(timezone.utc).replace(tzinfo=None)

    current_hour = now.hour
    nearest_possible_model_hour = model_hour_interval * (
        current_hour // model_hour_interval
    )

    delta_hours = current_hour - nearest_possible_model_hour
    # NOTE(review): the 3 hour threshold presumably allows for publication
    # delay of a fresh cycle - confirm against the data provider.
    if delta_hours <= 3:
        # Step back exactly one cycle. This was hard-coded to 6, which landed
        # on a non-cycle hour for any interval other than 6.
        delta_hours += model_hour_interval

    # Subtracting the sub-hour components truncates the result to HH:00:00.
    return now - timedelta(
        hours=delta_hours,
        minutes=now.minute,
        seconds=now.second,
        microseconds=now.microsecond,
    )


def get_latest_run(
    model_config: ModelConfig, max_runs_to_try: int = 3
) -> Optional[datetime]:
    """Attempts to find a complete model run available on S3
    from a specified number of previous run times.

    Starting from the latest possible run and stepping back one model
    interval at a time, the objects under each run's grib file-name prefix
    are counted. The first run whose count equals
    ``model_config.s3_complete_run_number_of_files`` is returned.

    Args:
        model_config (ModelConfig): The model configuration.
        max_runs_to_try (int, optional): The number of runs to check for
                                         completeness. Defaults to 3.

    Returns:
        Optional[datetime]: Datetime object representing the latest complete
            run, or None when none of the checked runs is complete.
    """
    expected_number_of_files = model_config.s3_complete_run_number_of_files
    model_hour_interval = model_config.model_hour_interval
    bucket = model_config.s3_bucket

    latest_possible_run = get_latest_possible_model_run(
        model_hour_interval=model_hour_interval
    )

    runs_to_try = [
        latest_possible_run - timedelta(hours=i * model_hour_interval)
        for i in range(max_runs_to_try)
    ]

    # The bucket is public, so requests are sent unsigned (no credentials).
    s3_client = boto3.client("s3", config=Config(signature_version=UNSIGNED))
    # A single list_objects_v2 call returns at most 1000 keys; paginate so the
    # count stays correct for runs with more objects than that.
    paginator = s3_client.get_paginator("list_objects_v2")

    for run in runs_to_try:
        # Drop the trailing zeros of the forecast-0 key so the prefix matches
        # every forecast hour of the run (and the matching .idx objects).
        prefix = model_config.get_s3_grib_key(run=run, forecast=0).rstrip("0")
        info(
            f"Checking {run:%Y%m%d %H}Z model run for completeness. (S3 prefix - s3://{bucket}/{prefix})"
        )
        num_files_found = sum(
            len(page.get("Contents", []))
            for page in paginator.paginate(Bucket=bucket, Prefix=prefix)
        )
        if num_files_found:
            info(
                f"Found {num_files_found} files available for {run:%Y%m%d %H}Z model run."
            )
            if num_files_found == expected_number_of_files:
                info(f"Found complete run: {run:%Y%m%d %HZ}")
                return run
        else:
            info(
                f"Found no files available for {run:%Y%m%d %H}Z model run. (S3 Prefix - {prefix})"
            )

    warning(f"No complete runs found in previous {max_runs_to_try} runs.")
    return None

In [None]:
def get_test_idx_file(
    model_config: ModelConfig, run: datetime, forecast: int
) -> Optional[str]:
    """Gets the idx file for the given run and forecast provided it does not already exist in the current directory.

    Args:
        model_config (ModelConfig): The model configuration.
        run (datetime): The run time of the model.
        forecast (int): The forecast hour.

    Returns:
        Optional[str]: The name of the local idx file, or None when the idx
            object does not exist on S3.

    Raises:
        botocore.exceptions.ClientError: For any S3 failure other than a
            missing object.
    """
    bucket = model_config.s3_bucket
    idx_file = model_config.get_s3_idx_key(run=run, forecast=forecast)
    local_file_name = model_config.get_idx_file_name(run=run, forecast=forecast)

    # Reuse a previously downloaded copy; no S3 client is needed in that case.
    if os.path.exists(local_file_name):
        info(f"File {idx_file} already exists locally.")
        return local_file_name

    # The bucket is public, so requests are sent unsigned (no credentials).
    s3_client = boto3.client("s3", config=Config(signature_version=UNSIGNED))
    try:
        info(f"Downloading {idx_file} from S3...")
        s3_client.download_file(bucket, idx_file, local_file_name)
    except s3_client.exceptions.ClientError as e:
        # download_file reports a missing key as a generic ClientError with a
        # "404" code rather than the modeled NoSuchKey exception, so inspect
        # the error code instead of relying on the exception type.
        error_code = str(e.response.get("Error", {}).get("Code", ""))
        if error_code not in ("404", "NoSuchKey", "NotFound"):
            raise
        warning(f"File {idx_file} not found on S3. {e}")
        # Nothing was written locally, so there is no file name to hand back.
        return None

    info(f"Successfully downloaded {idx_file}.")
    return local_file_name

In [None]:
def parse_grib_index(index: StringIO) -> Dict[str, Dict[str, Dict[str, Optional[int]]]]:
    """Parses a grib index file into a usable dictionary.

    Each non-blank line is colon separated; the second field is read as the
    message's start byte, the fourth as the parameter and the fifth as the
    level. Any further fields are ignored. A message's stop byte is one less
    than the start byte of the next message in the file; the last message has
    no known stop byte (None).

    Args:
        index (StringIO): A seekable text stream (e.g. StringIO or an open
            text file) holding the contents of the grib index file.
    Returns:
        Dict[str, Dict[str, Dict[str, Optional[int]]]]: A dictionary containing parameter, level,
            and start/stop byte addresses. When a parameter/level pair occurs
            more than once, the occurrence nearest the top of the file wins.
    Raises:
        ValueError: If a line has fewer than five fields or a non-integer
            start byte.
    """
    result: Dict[str, Dict[str, Dict[str, Optional[int]]]] = {}
    index.seek(0)

    # Walk the file bottom-up so the following message's start byte is already
    # known when the current message's stop byte is computed.
    next_start: Optional[int] = None
    next_stop: Optional[int] = None
    for raw_line in reversed(index.read().splitlines()):
        line = raw_line.strip()
        if not line:
            continue

        fields = line.rstrip(":").split(":")
        if len(fields) < 5:
            raise ValueError(f"Malformed grib index line: {raw_line!r}")
        start = int(fields[1])
        parameter, level = fields[3], fields[4]

        if next_start is None:
            # Last message in the file: it runs to the end of the grib file.
            stop = None
        elif start == next_start:
            # Entries sharing a byte offset belong to the same message, so
            # they share its stop byte instead of getting stop < start.
            stop = next_stop
        else:
            stop = next_start - 1

        result.setdefault(parameter, {})[level] = {"start": start, "stop": stop}
        next_start, next_stop = start, stop

    return result

In [None]:
def get_grib_message_from_s3(
    bucket: str, key: str, start_byte: int, stop_byte: Optional[int] = None
) -> bytes:
    """Downloads a single grib message from s3 using a ranged GET request.

    Args:
        bucket (str): Name of the S3 bucket.
        key (str): Key of the grib file inside the bucket.
        start_byte (int): First byte of the message (inclusive).
        stop_byte (Optional[int]): Last byte of the message (inclusive). None
            requests everything from ``start_byte`` to the end of the object.

    Returns:
        bytes: The raw bytes of the requested range.
    """
    # The bucket is public, so requests are sent unsigned (no credentials).
    s3_client = boto3.client("s3", config=Config(signature_version=UNSIGNED))
    # An open-ended HTTP range is written "bytes=N-"; formatting None directly
    # would produce the invalid header "bytes=N-None".
    last_byte = "" if stop_byte is None else stop_byte
    response = s3_client.get_object(
        Bucket=bucket, Key=key, Range=f"bytes={start_byte}-{last_byte}"
    )
    return response["Body"].read()

In [None]:
# Look up the newest complete run; its index file is only fetched when a run
# was actually found (latest_run is None otherwise).
latest_run = get_latest_run(model_config=gfs_config)
if latest_run:
    idx_test_file = get_test_idx_file(
        run=latest_run, forecast=0, model_config=gfs_config
    )

In [None]:
# Fail fast with a readable message instead of an opaque formatting error
# when no complete run was found.
if latest_run is None:
    raise RuntimeError(
        "No complete model run was found, so there is no GRIB file to read."
    )

grib_key = gfs_config.get_s3_grib_key(run=latest_run, forecast=0)
print(grib_key)

# Returns the locally cached copy when the index file was downloaded before.
idx_test_file = get_test_idx_file(
    model_config=gfs_config, run=latest_run, forecast=0
)
if not idx_test_file or not os.path.exists(idx_test_file):
    raise FileNotFoundError(f"Index file for {grib_key} could not be obtained.")

with open(idx_test_file, "r") as f:
    grib_index = parse_grib_index(index=f)

bucket = gfs_config.s3_bucket
# Look up the byte range of the 2 m temperature message in the parsed index.
parameter = grib_index.get("TMP")
level = parameter.get("2 m above ground") if parameter else None
if level is None:
    raise KeyError(
        f"TMP at '2 m above ground' is not listed in index file {idx_test_file}."
    )
start_byte = level.get("start")
stop_byte = level.get("stop")
grib_message = get_grib_message_from_s3(
    bucket=bucket, key=grib_key, start_byte=start_byte, stop_byte=stop_byte
)
# Sanity check: the downloaded range must begin with the "GRIB" marker.
assert grib_message[:4] == b"GRIB", "Downloaded byte range does not start with b'GRIB'."

In [None]:
# Point at which the 2 m temperature field is interpolated.
# NOTE(review): if the grid's longitudes run 0-360 degrees east, 105.0 means
# 105E; use 255.0 if 105W was intended - confirm against the printed coords.
lat = 40.0
lon = 105.0
if grib_message is not None:
    # cfgrib reads from a path, so the message is staged in a temporary file
    # that is removed when the block exits.
    with NamedTemporaryFile(mode="wb") as file:
        file.write(grib_message)
        # Push buffered bytes to disk before the file is re-opened by name;
        # without this the reader can see an empty or truncated file.
        file.flush()
        # indexpath="" stops cfgrib from leaving an index file next to the
        # temporary file; the context manager closes the dataset again.
        with xr.open_dataset(
            file.name, engine="cfgrib", backend_kwargs={"indexpath": ""}
        ) as ds:
            print(ds.coords)
            print(ds.t2m.interp(latitude=lat, longitude=lon, method="linear").values)