In [None]:
import logging
import os
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from enum import Enum
from io import BytesIO, StringIO
from tempfile import NamedTemporaryFile
from typing import Dict, Optional, Tuple

import boto3
import numpy as np
import pandas as pd
import xarray as xr
from botocore import UNSIGNED
from botocore.config import Config
from botocore.exceptions import ClientError
from numpy.typing import ArrayLike, NDArray

%load_ext pyinstrument

In [None]:
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
# Handle to the root logger. The module-level logging.info/warning calls used in
# the helpers below also log through the root logger.
logger = logging.getLogger()


In [None]:
class Parameter(str, Enum):
    """Common weather parameter names used throughout this notebook.

    Members subclass ``str``, so each member compares equal to its value
    (e.g. ``Parameter.TEMPERATURE == "temperature"``). Model-specific names
    for these parameters are looked up through ``ModelConfig.parameters``.
    """

    TEMPERATURE = "temperature"
    RELATIVE_HUMIDITY = "relative_humidity"
    DEWPOINT_TEMPERATURE = "dewpoint_temperature"
    WIND_GUST = "wind_gust"
    WIND_U = "wind_u"
    WIND_V = "wind_v"
    # The two members below are derived from WIND_U / WIND_V rather than read
    # directly from a GRIB message (they have no entry in gfs_config.parameters).
    WIND_SPEED = "wind_speed"
    WIND_DIRECTION = "wind_direction"


class Level(str, Enum):
    """Vertical levels a parameter can be requested at.

    Members subclass ``str``; the model-specific level description (as it
    appears in a GRIB index file) is looked up through ``ModelConfig.levels``.
    """

    TWO_M_ABOVE_GROUND = "2_meters_above_ground"
    TEN_M_ABOVE_GROUND = "10_meters_above_ground"
    SURFACE = "surface"


@dataclass
class ModelParameter:
    """A parameter/level pair to fetch, plus the variable name to read after decoding."""

    parameter: Parameter
    level: Level
    # Name of the data variable in the dataset that cfgrib decodes from the GRIB
    # message for this parameter/level (read as ``ds[cf_variable]`` in ``main``).
    cf_variable: str


# Parameters fetched by main(), one GRIB message each. Wind speed and direction
# are intentionally absent: they are derived afterwards from WIND_U and WIND_V.
desired_model_parameters = [
    ModelParameter(
        parameter=Parameter.TEMPERATURE,
        level=Level.TWO_M_ABOVE_GROUND,
        cf_variable="t2m",
    ),
    ModelParameter(
        parameter=Parameter.RELATIVE_HUMIDITY,
        level=Level.TWO_M_ABOVE_GROUND,
        cf_variable="r2",
    ),
    ModelParameter(
        parameter=Parameter.DEWPOINT_TEMPERATURE,
        level=Level.TWO_M_ABOVE_GROUND,
        cf_variable="d2m",
    ),
    ModelParameter(
        parameter=Parameter.WIND_GUST, level=Level.SURFACE, cf_variable="gust"
    ),
    ModelParameter(
        parameter=Parameter.WIND_U, level=Level.TEN_M_ABOVE_GROUND, cf_variable="u10"
    ),
    ModelParameter(
        parameter=Parameter.WIND_V, level=Level.TEN_M_ABOVE_GROUND, cf_variable="v10"
    ),
]

In [None]:
@dataclass
class ModelConfig:
    """Where a forecast model's output lives on S3, and how our names map to the model's.

    The ``*_pattern`` fields are ``str.format`` templates. They are formatted
    with ``run`` (the model run datetime) and, for file names, ``forecast``
    (the integer forecast hour).
    """

    short_name: str
    model_hour_interval: int
    s3_bucket: str
    s3_prefix_pattern: str
    s3_grib_file_name_pattern: str
    s3_idx_file_name_pattern: str
    s3_complete_run_number_of_files: int
    parameters: Dict[str, str]
    levels: Dict[str, str]

    def get_s3_grib_key(self, run: datetime, forecast: int) -> str:
        """Full S3 key (prefix + file name) of the GRIB file for run/forecast."""
        return self.get_s3_prefix(run=run) + self.get_grib_file_name(
            run=run, forecast=forecast
        )

    def get_s3_idx_key(self, run: datetime, forecast: int) -> str:
        """Full S3 key (prefix + file name) of the GRIB index file for run/forecast."""
        return self.get_s3_prefix(run=run) + self.get_idx_file_name(
            run=run, forecast=forecast
        )

    def get_grib_file_name(self, run: datetime, forecast: int) -> str:
        """Bare GRIB file name (no S3 prefix) for run/forecast."""
        template = self.s3_grib_file_name_pattern
        return template.format(run=run, forecast=forecast)

    def get_idx_file_name(self, run: datetime, forecast: int) -> str:
        """Bare GRIB index file name (no S3 prefix) for run/forecast."""
        template = self.s3_idx_file_name_pattern
        return template.format(run=run, forecast=forecast)

    def get_s3_prefix(self, run: datetime) -> str:
        """S3 key prefix shared by every file belonging to ``run``."""
        template = self.s3_prefix_pattern
        return template.format(run=run)

    def get_model_parameter_name(self, parameter: Parameter) -> Optional[str]:
        """Model-specific name for ``parameter``, or None if the model lacks it."""
        return self.parameters.get(parameter, None)

    def get_model_level_name(self, level: Level) -> Optional[str]:
        """Model-specific description of ``level``, or None if the model lacks it."""
        return self.levels.get(level, None)


# GFS output (pgrb2.0p25 files) in the public NOAA bucket. The S3 clients in
# this notebook read it anonymously (signature_version=UNSIGNED).
gfs_config = ModelConfig(
    short_name="gfs",
    # A new run starts every 6 hours, at hours that are multiples of 6 (00Z, 06Z, 12Z, 18Z).
    model_hour_interval=6,
    s3_bucket="noaa-gfs-bdp-pds",
    s3_prefix_pattern="gfs.{run:%Y%m%d}/{run:%H}/atmos/",
    s3_grib_file_name_pattern="gfs.t{run:%H}z.pgrb2.0p25.f{forecast:03d}",
    s3_idx_file_name_pattern="gfs.t{run:%H}z.pgrb2.0p25.f{forecast:03d}.idx",
    # Object count under a run's forecast-file prefix that get_latest_run treats
    # as a complete run.
    # NOTE(review): presumably one GRIB file plus one .idx per forecast hour;
    # confirm against the bucket if the GFS output schedule changes.
    s3_complete_run_number_of_files=418,
    # Our Parameter -> the parameter name as it appears in the GRIB .idx file.
    parameters={
        Parameter.TEMPERATURE: "TMP",
        Parameter.RELATIVE_HUMIDITY: "RH",
        Parameter.DEWPOINT_TEMPERATURE: "DPT",
        Parameter.WIND_GUST: "GUST",
        Parameter.WIND_U: "UGRD",
        Parameter.WIND_V: "VGRD",
    },
    # Our Level -> the level description as it appears in the GRIB .idx file.
    levels={
        Level.TWO_M_ABOVE_GROUND: "2 m above ground",
        Level.TEN_M_ABOVE_GROUND: "10 m above ground",
        Level.SURFACE: "surface",
    },
)

In [None]:
def get_latest_possible_model_run(
    model_hour_interval: int,
    availability_delay_hours: int = 3,
    now: Optional[datetime] = None,
) -> datetime:
    """Gets the latest model run that could plausibly be published by now (UTC).

    Runs are assumed to start at hours that are multiples of
    ``model_hour_interval`` (counting from 00Z). If no more than
    ``availability_delay_hours`` whole hours have elapsed since the most recent
    run's nominal time, that run is treated as not yet published and the run
    before it is returned instead.

    Args:
        model_hour_interval (int): The interval between model runs in hours.
        availability_delay_hours (int, optional): Whole hours after a run's
            nominal time during which it is still considered unpublished.
            Defaults to 3.
        now (Optional[datetime], optional): Reference time as a naive UTC
            datetime. Defaults to the current UTC time.
    Returns:
        datetime: Naive UTC datetime of the latest possible model run,
            truncated to the hour.
    """
    if now is None:
        # Naive UTC, matching the naive datetimes used elsewhere in this notebook.
        now = datetime.now(timezone.utc).replace(tzinfo=None)

    current_hour = now.hour
    nearest_possible_model_hour = model_hour_interval * (
        current_hour // model_hour_interval
    )

    delta_hours = current_hour - nearest_possible_model_hour
    if delta_hours <= availability_delay_hours:
        # The most recent run started too recently to be on S3 yet: step back
        # exactly one run interval (not a fixed 6 h, so non-6-hourly models work).
        delta_hours += model_hour_interval

    top_of_hour = now.replace(minute=0, second=0, microsecond=0)
    return top_of_hour - timedelta(hours=delta_hours)


def get_latest_run(
    model_config: ModelConfig, max_runs_to_try: int = 3
) -> Optional[datetime]:
    """Finds the most recent complete model run available on S3.

    Starting from the latest run that could plausibly be published, steps back
    one run interval at a time and returns the first run whose forecast-file
    prefix holds exactly ``model_config.s3_complete_run_number_of_files``
    objects.

    Args:
        model_config (ModelConfig): The model configuration.
        max_runs_to_try (int, optional): The number of runs to check for
                                         completeness. Defaults to 3.

    Returns:
        Optional[datetime]: The latest complete run, or None if none of the
            checked runs is complete.
    """
    expected_number_of_files = model_config.s3_complete_run_number_of_files
    model_hour_interval = model_config.model_hour_interval
    bucket = model_config.s3_bucket

    latest_possible_run = get_latest_possible_model_run(
        model_hour_interval=model_hour_interval
    )

    runs_to_try = [
        latest_possible_run - timedelta(hours=i * model_hour_interval)
        for i in range(max_runs_to_try)
    ]

    s3_client = boto3.client("s3", config=Config(signature_version=UNSIGNED))
    # A single list_objects_v2 call returns at most 1000 keys; paginate so runs
    # with more objects than that are still counted in full.
    paginator = s3_client.get_paginator("list_objects_v2")

    for run in runs_to_try:
        # For file patterns ending in a zero-padded forecast hour (as in
        # gfs_config), stripping the trailing zeros of the f000 key leaves a
        # prefix that matches every forecast file of this run.
        prefix = model_config.get_s3_grib_key(run=run, forecast=0).rstrip("0")
        logging.info(
            f"Checking {run:%Y%m%d %H}Z model run for completeness. (S3 prefix - s3://{bucket}/{prefix})"
        )
        num_files_found = sum(
            len(page.get("Contents", []))
            for page in paginator.paginate(Bucket=bucket, Prefix=prefix)
        )
        if num_files_found:
            logging.info(
                f"Found {num_files_found} files available for {run:%Y%m%d %H}Z model run."
            )
            if num_files_found == expected_number_of_files:
                latest_run = run
                logging.info(f"Found complete run: {latest_run:%Y%m%d %HZ}")
                return latest_run
        else:
            logging.info(
                f"Found no files available for {run:%Y%m%d %H}Z model run. (S3 Prefix - {prefix})"
            )

    logging.warning(f"No complete runs found in previous {max_runs_to_try} runs.")
    return None

In [None]:
def get_test_idx_file(
    model_config: ModelConfig, run: datetime, forecast: int
) -> Optional[str]:
    """Downloads the GRIB index (.idx) file for run/forecast into the current directory.

    Nothing is downloaded if a file with the same name already exists locally.

    Args:
        model_config (ModelConfig): The model configuration.
        run (datetime): The run time of the model.
        forecast (int): The forecast hour.

    Returns:
        Optional[str]: The local file name of the idx file, or None if it was
            not present locally and could not be downloaded.
    """
    bucket = model_config.s3_bucket
    idx_file = model_config.get_s3_idx_key(run=run, forecast=forecast)
    local_file_name = model_config.get_idx_file_name(run=run, forecast=forecast)

    if os.path.exists(local_file_name):
        logging.info(f"File {idx_file} already exists locally.")
        return local_file_name

    s3_client = boto3.client("s3", config=Config(signature_version=UNSIGNED))
    try:
        logging.info(f"Downloading {idx_file} from S3...")
        s3_client.download_file(bucket, idx_file, local_file_name)
    except ClientError as e:
        # download_file reports a missing key as a generic ClientError (HTTP 404
        # from its HEAD request), not as NoSuchKey, so catch the base class.
        logging.warning(f"File {idx_file} not found on S3. {e}")
        return None

    logging.info(f"Successfully downloaded {idx_file}.")
    return local_file_name

In [None]:
def parse_grib_index(index: str) -> Dict[str, Dict[str, Dict[str, Optional[int]]]]:
    """Parses the text of a GRIB index (.idx) file into per-message byte ranges.

    Each non-empty line is colon-separated. Field 1 is the message's starting
    byte offset, field 3 its parameter name, and field 4 its level description.
    Any further fields (forecast-time descriptors, ensemble info, ...) are
    ignored.

    A message ends one byte before the next message starts. The last message
    therefore gets ``stop=None``, meaning "read to the end of the file". If the
    same parameter/level pair appears more than once, the first occurrence in
    the file wins.

    Args:
        index (str): The contents of the grib index file.
    Returns:
        Dict[str, Dict[str, Dict[str, Optional[int]]]]: ``result[parameter][level]``
            is ``{"start": int, "stop": Optional[int]}``, where ``stop`` is inclusive.
    """
    result: Dict[str, Dict[str, Dict[str, Optional[int]]]] = {}
    next_start: Optional[int] = None
    # Walk bottom-up so each message's end is known from the message after it.
    # splitlines() also copes with CRLF line endings.
    for raw_line in reversed(index.splitlines()):
        line = raw_line.strip()
        if not line:
            continue
        # Index fields by position instead of tuple-unpacking, so lines with
        # extra trailing fields do not raise.
        fields = line.split(":")
        start = int(fields[1])
        parameter, level = fields[3], fields[4]
        stop = next_start - 1 if next_start is not None else None
        # Bottom-up overwriting means the first occurrence in file order wins.
        result.setdefault(parameter, {})[level] = {"start": start, "stop": stop}
        next_start = start
    return result

In [None]:
def get_data_from_s3(
    bucket: str,
    key: str,
    start_byte: Optional[int] = None,
    stop_byte: Optional[int] = None,
) -> bytes:
    """Downloads data (bytes) from S3 anonymously, optionally only a byte range.

    Args:
        bucket (str): The S3 bucket name.
        key (str): The object key.
        start_byte (Optional[int]): First byte to read (inclusive). Treated as 0
            when only ``stop_byte`` is given.
        stop_byte (Optional[int]): Last byte to read (inclusive). When None and
            ``start_byte`` is set, reads from ``start_byte`` to the end of the
            object. This is what the last message of a GRIB index needs.

    Returns:
        bytes: The whole object when both bounds are None, otherwise the range.
    """
    s3_client = boto3.client("s3", config=Config(signature_version=UNSIGNED))
    if start_byte is None and stop_byte is None:
        response = s3_client.get_object(Bucket=bucket, Key=key)
    else:
        # HTTP Range bounds are inclusive; an empty end ("bytes=N-") means
        # "to the end of the object".
        first = 0 if start_byte is None else start_byte
        last = "" if stop_byte is None else stop_byte
        response = s3_client.get_object(
            Bucket=bucket, Key=key, Range=f"bytes={first}-{last}"
        )
    return response["Body"].read()


def get_grib_message_from_s3(
    model_config: ModelConfig,
    grib_index: dict,
    run: datetime,
    forecast: int,
    parameter: Parameter,
    level: Level,
) -> Optional[bytes]:
    """Downloads a single GRIB message from S3 using its byte range in ``grib_index``.

    Args:
        model_config (ModelConfig): The model configuration.
        grib_index (dict): Parsed index, as returned by ``parse_grib_index``.
        run (datetime): The model run time.
        forecast (int): The forecast hour.
        parameter (Parameter): The parameter to fetch.
        level (Level): The level to fetch.

    Returns:
        Optional[bytes]: The raw GRIB message, or None if the model has no name
            for the parameter/level or the index has no matching message.
    """
    bucket = model_config.s3_bucket
    key = model_config.get_s3_grib_key(run=run, forecast=forecast)
    model_parameter_name = model_config.get_model_parameter_name(parameter)
    model_level_name = model_config.get_model_level_name(level)
    if not (model_parameter_name and model_level_name):
        return None

    byte_locations = grib_index.get(model_parameter_name, {}).get(model_level_name)
    if byte_locations is None:
        logging.warning(
            f"{model_parameter_name} at {model_level_name} not found in the GRIB index for {key}."
        )
        return None

    return get_data_from_s3(
        bucket, key, byte_locations["start"], byte_locations["stop"]
    )


def get_grib_idx_from_s3(
    model_config: ModelConfig, run: datetime, forecast: int
) -> Optional[bytes]:
    """Fetches the raw bytes of the GRIB index (.idx) file for run/forecast from S3.

    The whole object is read (no byte range).
    """
    return get_data_from_s3(
        model_config.s3_bucket,
        model_config.get_s3_idx_key(run=run, forecast=forecast),
    )

In [None]:
def convert_user_lon_to_gfs_lon(lon: float) -> float:
    """Converts a longitude to the 0-360 degree convention used by the GFS grid.

    Negative (western-hemisphere) longitudes are shifted by 360 degrees. Any
    value is wrapped into [0, 360), so inputs such as -190 or 370 also land on
    the grid. A tiny negative input can round to exactly 360.0 in floating point.

    Args:
        lon (float): The longitude to convert.

    Returns:
        float: The longitude used by the GFS model.
    """
    return lon % 360

In [None]:
def get_forecast_hour_for_valid_time(model_run: datetime, valid_time: datetime) -> int:
    """Returns the lead time from ``model_run`` to ``valid_time`` in whole hours.

    Fractional hours are truncated toward zero. A valid time earlier than the
    run therefore yields a non-positive result.

    Args:
        model_run (datetime): The model run time.
        valid_time (datetime): The valid time.

    Returns:
        int: The forecast hour.
    """
    seconds_per_hour = 3600
    lead = valid_time - model_run
    return int(lead.total_seconds() / seconds_per_hour)


def test_get_forecast_hour_for_valid_time() -> None:
    """A valid time six hours after a 00Z run is forecast hour 6."""
    run_time = datetime(2020, 1, 1, 0)
    six_hours_later = datetime(2020, 1, 1, 6)
    lead_hours = get_forecast_hour_for_valid_time(run_time, six_hours_later)
    assert lead_hours == 6


def get_valid_time_for_forecast_hour(
    model_run: datetime, forecast_hour: int
) -> datetime:
    """Returns the time at which forecast hour ``forecast_hour`` of ``model_run`` is valid.

    Args:
        model_run (datetime): The model run time.
        forecast_hour (int): The forecast hour.

    Returns:
        datetime: ``model_run`` shifted forward by ``forecast_hour`` hours.
    """
    lead_time = timedelta(hours=forecast_hour)
    return model_run + lead_time


def test_get_valid_time_for_forecast_hour() -> None:
    """Forecast hour 6 of a 00Z run is valid at 06Z the same day."""
    run_time = datetime(2020, 1, 1, 0)
    expected = datetime(2020, 1, 1, 6)
    assert get_valid_time_for_forecast_hour(run_time, 6) == expected


# Inline sanity checks for the forecast-hour / valid-time helpers; each raises
# AssertionError if the helper regresses.
for _sanity_check in (
    test_get_forecast_hour_for_valid_time,
    test_get_valid_time_for_forecast_hour,
):
    _sanity_check()
del _sanity_check

In [None]:
@dataclass
class ParamResult:
    """The result of a parameter query: one value of one parameter at one point and time.

    ``main`` collects these into a list and builds a ``pd.DataFrame`` from them.
    """

    # Parameter name as a Parameter enum value string (e.g. "temperature").
    parameter: str
    # Level name as a Level enum value string (e.g. "2_meters_above_ground").
    level: str
    # Requested latitude.
    latitude: float
    # Requested longitude in the caller's convention (negative = west), not the
    # GFS 0-360 convention used for the interpolation.
    longitude: float
    # Value interpolated from the model grid (or derived, for wind speed/direction).
    value: float
    # Time the value is valid for.
    valid_time: datetime

In [None]:

def calculate_wind_speed_from_uv(u: float, v: float) -> float:
    """Returns the wind speed, i.e. the magnitude sqrt(u**2 + v**2) of the U/V vector.

    Works element-wise on NumPy arrays as well as on scalars. The result has the
    same units as the components.

    Args:
        u (float): The U component.
        v (float): The V component.

    Returns:
        float: The wind speed.
    """
    squared_magnitude = u**2 + v**2
    return np.sqrt(squared_magnitude)


def calculate_wind_dir_from_uv(u: float, v: float) -> float:
    """Calculates the meteorological wind direction from the U and V components.

    The result is the compass direction the wind blows *from*, in degrees in
    [0, 360): 0 = from the north, 90 = from the east, 180 = from the south,
    270 = from the west. Works element-wise on array inputs.

    Args:
        u (float): The U (eastward) component.
        v (float): The V (northward) component.

    Returns:
        float: The wind direction in degrees.
    """
    # arctan2(u, v) is the bearing the wind blows *toward*, in (-180, 180].
    # Negating both components turns it into where the wind comes from, and the
    # modulo maps the result onto [0, 360).
    return np.mod(np.degrees(np.arctan2(np.negative(u), np.negative(v))), 360)


def calc_wind_speed_and_dir_from_uv(
    u: ArrayLike, v: ArrayLike
) -> Tuple[NDArray, NDArray]:
    """Calculates the wind speed and direction from the U and V components.

    Args:
        u (ArrayLike): The U component.
        v (ArrayLike): The V component.

    Returns:
        Tuple[NDArray, NDArray]: The wind speed and the direction, computed by
            calculate_wind_speed_from_uv and calculate_wind_dir_from_uv.
    """
    speed = calculate_wind_speed_from_uv(u, v)
    # Named `direction` rather than `dir` to avoid shadowing the builtin.
    direction = calculate_wind_dir_from_uv(u, v)
    return speed, direction


def create_wind_speed_and_direction_results(df) -> list[ParamResult]:
    """Derives wind speed and direction results from the U and V component rows.

    U and V rows are paired by (latitude, longitude, level, valid_time) instead
    of by row position. Each pair yields one WIND_SPEED and one WIND_DIRECTION
    result carrying that pair's location, level and valid time. Locations
    missing either component are skipped.

    Args:
        df (pd.DataFrame): Rows shaped like ParamResult (columns parameter,
            level, latitude, longitude, value, valid_time).

    Returns:
        list[ParamResult]: Speed/direction results, two per paired location;
            empty when ``df`` is empty or has no U/V pairs.
    """
    if df.empty:
        return []

    keys = ["latitude", "longitude", "level", "valid_time"]
    wind_u = df.loc[df.parameter == Parameter.WIND_U, keys + ["value"]]
    wind_v = df.loc[df.parameter == Parameter.WIND_V, keys + ["value"]]
    paired = wind_u.merge(wind_v, on=keys, suffixes=("_u", "_v"))
    logging.info("Calculating wind speed and direction for {} locations".format(len(paired)))

    speeds, directions = calc_wind_speed_and_dir_from_uv(
        paired.value_u.to_numpy(), paired.value_v.to_numpy()
    )

    results = []
    for row, speed, direction in zip(
        paired.itertuples(index=False), speeds, directions
    ):
        for parameter, value in (
            (Parameter.WIND_SPEED, speed),
            (Parameter.WIND_DIRECTION, direction),
        ):
            results.append(
                ParamResult(
                    parameter=parameter.value,
                    level=row.level,
                    latitude=row.latitude,
                    longitude=row.longitude,
                    value=value,
                    valid_time=row.valid_time,
                )
            )

    return results

In [None]:
def _interpolate_grib_message_at_locations(
    grib_message: bytes, cf_variable: str, locations
) -> list:
    """Decodes one GRIB message with cfgrib and linearly interpolates ``cf_variable`` at each (lat, lon)."""
    values = []
    with NamedTemporaryFile(mode="wb") as file:
        file.write(grib_message)
        # cfgrib reopens the file by name, so flush buffered bytes to disk first.
        file.flush()
        # indexpath="" stops cfgrib from leaving a sidecar .idx file next to the temp file.
        with xr.open_dataset(
            file.name, engine="cfgrib", backend_kwargs={"indexpath": ""}
        ) as ds:
            data_array = ds[cf_variable]
            for latitude, longitude in locations:
                gfs_longitude = convert_user_lon_to_gfs_lon(lon=longitude)
                value = (
                    data_array
                    .interp(latitude=latitude, longitude=gfs_longitude, method="linear")
                    .values
                    .item()
                )
                values.append(value)
    return values


def main() -> pd.DataFrame:
    """Fetches point values of ``desired_model_parameters`` from the latest complete GFS run.

    The valid time is the current UTC hour. Each parameter is read from its own
    GRIB message (downloaded by byte range) and interpolated at every desired
    location. Wind speed and direction rows are then derived from U and V.

    Returns:
        pd.DataFrame: One row per (parameter, location) with ParamResult columns,
            followed by the derived wind speed/direction rows.

    Raises:
        RuntimeError: If no complete model run is found on S3.
        ValueError: If a downloaded message does not start with a GRIB header.
    """
    desired_locations = [
        (41.10, -95.94),
        (45.0, -105.0),
        (40.0, 105.0),
    ]

    model_config = gfs_config

    model_run = get_latest_run(model_config=model_config)
    if model_run is None:
        raise RuntimeError("No complete model run found on S3.")
    print(f"{model_run=}")

    # Truncate the current UTC time to the hour directly. Converting through
    # .timestamp() would interpret the naive UTC value as local time.
    valid_time = datetime.now(timezone.utc).replace(
        minute=0, second=0, microsecond=0, tzinfo=None
    )
    print(f"valid time: {valid_time}")
    forecast = get_forecast_hour_for_valid_time(
        model_run=model_run, valid_time=valid_time
    )
    print(f"{forecast=}")

    grib_index_data = get_grib_idx_from_s3(
        model_config=model_config, run=model_run, forecast=forecast
    )
    grib_index = parse_grib_index(grib_index_data.decode())

    results = []
    for param_config in desired_model_parameters:
        param = param_config.parameter
        level = param_config.level
        logging.info(f"Getting {param} at {level} for {len(desired_locations)} locations")

        grib_message = get_grib_message_from_s3(
            model_config=model_config,
            grib_index=grib_index,
            run=model_run,
            forecast=forecast,
            parameter=param,
            level=level,
        )
        if grib_message is None:
            logging.warning(f"No GRIB message available for {param} at {level}; skipping.")
            continue
        if grib_message[:4] != b"GRIB":
            raise ValueError(f"Downloaded data for {param} at {level} is not a GRIB message.")

        values = _interpolate_grib_message_at_locations(
            grib_message, param_config.cf_variable, desired_locations
        )
        for (latitude, longitude), value in zip(desired_locations, values):
            results.append(
                ParamResult(
                    parameter=param.value,
                    level=level.value,
                    latitude=latitude,
                    longitude=longitude,
                    value=value,
                    valid_time=valid_time,
                )
            )

    df = pd.DataFrame(results)
    wind_df = pd.DataFrame(create_wind_speed_and_direction_results(df))
    return pd.concat([df, wind_df], ignore_index=True)


# Run the full pipeline once and display the resulting point-forecast table.
df = main()
df

In [None]:
%%pyinstrument
# Profile an end-to-end call to main() with pyinstrument (loaded via %load_ext in the imports cell).
main()

In [None]:
# Scratch check: fetch one GRIB message (10 m U wind, forecast hour 0) from the
# latest complete GFS run, save it as test.grib, and open it with cfgrib.
level = Level.TEN_M_ABOVE_GROUND
parameter = Parameter.WIND_U

latest_run = get_latest_run(model_config=gfs_config)
if latest_run is None:
    raise RuntimeError("No complete GFS run found on S3; cannot run the single-message check.")

test_idx_file = get_test_idx_file(model_config=gfs_config, run=latest_run, forecast=0)
if test_idx_file is None:
    raise RuntimeError(f"Could not obtain the GRIB index file for the {latest_run:%Y%m%d %H}Z run.")
with open(test_idx_file, mode="rb") as file:
    grib_index = parse_grib_index(file.read().decode())

grib_message = get_grib_message_from_s3(
    model_config=gfs_config,
    grib_index=grib_index,
    run=latest_run,
    forecast=0,
    parameter=parameter,
    level=level,
)
# Check for None before slicing; the old order raised TypeError instead of a clear message.
if grib_message is None:
    raise RuntimeError(f"No GRIB message found for {parameter} at {level}.")
assert grib_message[:4] == b"GRIB", "Downloaded bytes do not start with a GRIB header."

with open("test.grib", mode="wb") as file:
    file.write(grib_message)

ds = xr.open_dataset("test.grib", engine="cfgrib")

ds

In [None]:
# Spot check: 10 m U wind at (40N, 105E); 105E is already in the 0-360 convention.
ds["u10"].interp(latitude=40.0, longitude=105.0, method="linear").values.item()

In [None]:
# Spot check: a western-hemisphere longitude mapped onto the GFS 0-360 grid.
convert_user_lon_to_gfs_lon(-95.93)