In [1]:
%load_ext jupyter_black


# wxlab/common.py

In [2]:
import os
from glob import glob
from typing import NamedTuple

import pandas as pd
import numpy as np
from numpy.typing import NDArray


@pd.api.extensions.register_dataframe_accessor("geo")
class GeoAccessor:
    def __init__(self, dataframe: pd.DataFrame) -> None:

        self._index = dataframe.index

    @property
    def lat(self) -> NDArray[np.float32]:
        return self._index.unique("lat").to_numpy().astype(np.float32)

    @property
    def lon(self) -> NDArray[np.float32]:
        return self._index.unique("lon").to_numpy().astype(np.float32)


def unpack_files(ALL_GALWEM_FILES, ALL_PROBSEVERE_FILES, time_buffer: int = 90) -> tuple[pd.Series, pd.Series]:
    # GALWEM FILES
    galwem = pd.Series(ALL_GALWEM_FILES, name="GALWEM")

    g_times: pd.DataFrame = galwem.str.extract(r"FH.(?P<forecast_hour>\d{3})_DF__(?P<valid_time>\d{8})")
    galwem.index = pd.to_datetime(g_times["valid_time"]) + pd.to_timedelta(
        g_times["forecast_hour"].astype(int), unit="h"
    )
    # PROBSEVERE_FILES
    probsevere = pd.Series(ALL_PROBSEVERE_FILES, name="ProbSevere")
    probsevere.index = pd.to_datetime(probsevere.str.replace("_", "T").str.extract(r"(\d*T\d*).json")[0])

    buffer = pd.to_timedelta(time_buffer, unit="m")
    condition = (probsevere.index > galwem.index.min() - buffer) & (probsevere.index < galwem.index.max() + buffer)

    return galwem, probsevere[condition]


class BBox(NamedTuple):
    minx: float = -130.0 % 360
    maxx: float = -60.0 % 360
    miny: float = 20.0
    maxy: float = 55.0

In [3]:
import json
from typing import Iterable

from geopandas import GeoDataFrame, GeoSeries
from geojson import FeatureCollection


def to_dataframe(files: pd.Series):
    def generate() -> Iterable[GeoDataFrame]:
        """load function for probsevere dataset"""
        for file in files.tolist():
            with open(file, mode="r", encoding="utf8") as fc:
                feat = json.load(fc)
                df = GeoDataFrame.from_features(feat["features"])
                df["validTime"] = feat["validTime"]
                yield df

    df = pd.concat(generate(), ignore_index=True)

    df["validTime"] = pd.to_datetime(df["validTime"], format="%Y%m%d_%H%M%S %Z", utc=True).astype(np.int64)

    df["AVG_BEAM_HGT"] = df["AVG_BEAM_HGT"].str.replace(r"[A-Za-z]", "", regex=True).apply(pd.eval)

    df[["MAXRC_EMISS", "MAXRC_ICECF"]] = (
        df[["MAXRC_EMISS", "MAXRC_ICECF"]]
        .stack()
        .str.extract(r"(?:\()([a-z]*)(?:\))")
        .replace({"weak": 1, "moderate": 2, "strong": 3})
        .fillna(0)
        .unstack(-1)
        .droplevel(0, axis=1)
    )

    return df.set_index(["validTime", "ID"])


def index_abs_argmin(
    ps_bounds: pd.DataFrame,  # (27070, 2)
    galwem_grid: NDArray[np.float32],  # (281,)
) -> NDArray[np.float32]:  # (27070, 2)
    """
    >>> ps_bounds:GeoDataFrames.geometry.bounds[["minx", "maxx","miny", "maxy"]]
    >>> galwem_grid: NDArray[np.float32]
    """

    # first shaped the probsevere and galwm so that have a common axis
    ps_shaped = ps_bounds.to_numpy()[:, np.newaxis]  # (27070, 1, 2)
    galwem_shaped = galwem_grid[:, np.newaxis]  # (141, 1)
    delta = abs(galwem_shaped - ps_shaped)  # (27070, 281, 2)
    # in the delta find the smallest diffrence here -- ^
    index_nearest = np.argmin(delta, axis=1)  # (27070, 2)
    # use the index position of the smallest diff to a grid point
    # to index the grid for a max and min
    return galwem_grid[index_nearest]  # (27070, 2)

In [4]:
import pandas as pd
import xarray as xr


def to_dataset(
    fp: pd.Series,
    bbox: BBox,
    mapping={
        "lv_ISBL0": "hPa",
        "lat_0": "lat",
        "lon_0": "lon",
        "TMP_P0_L100_GLL0": "temp",
        "UGRD_P0_L100_GLL0": "u_wind",
        "VGRD_P0_L100_GLL0": "v_wind",
    },
):
    def generate_dataset() -> xr.Dataset:
        for timestamp, file in fp.items():
            ds: xr.Dataset = xr.load_dataset(file, engine="pynio")

            yield ds.expand_dims({"validTime": [timestamp.value]}).rename(mapping)

    ds = xr.concat(generate_dataset(), dim="validTime")
    if isinstance(bbox, BBox):
        condition = (ds.lon >= bbox.minx) & (ds.lon <= bbox.maxx) & (ds.lat >= bbox.miny) & (ds.lat <= bbox.maxy)
        ds = ds.where(condition, drop=True)
    return ds

# wxlab/grib.py

# wxlab/common.py

In [5]:
ALL_GALWEM_FILES = sorted(glob(os.path.join("data", "galwem", "*.GR2")))
ALL_PROBSEVERE_FILES = sorted(glob(os.path.join("data", "probsevere", "*.json")))
glwm, probsevere = unpack_files(ALL_GALWEM_FILES, ALL_PROBSEVERE_FILES)
df = to_dataset(glwm, bbox=BBox()).to_dataframe()
# ds =
df.columns.set_names("elements", inplace=True)

GALWEM = df.unstack("hPa").reorder_levels(["validTime", "lat", "lon"]).reorder_levels(["hPa", "elements"], axis=1)
GALWEM

Unnamed: 0_level_0,Unnamed: 1_level_0,hPa,3000.0,5000.0,7000.0,10000.0,15000.0,20000.0,25000.0,30000.0,40000.0,50000.0,...,10000.0,15000.0,20000.0,25000.0,30000.0,40000.0,50000.0,70000.0,85000.0,100000.0
Unnamed: 0_level_1,Unnamed: 1_level_1,elements,temp,temp,temp,temp,temp,temp,temp,temp,temp,temp,...,v_wind,v_wind,v_wind,v_wind,v_wind,v_wind,v_wind,v_wind,v_wind,v_wind
validTime,lat,lon,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2
1653004800000000000,20.0,230.00,218.069992,207.849991,199.869995,202.089996,209.399994,217.589996,229.019989,239.429993,252.839996,264.070007,...,13.250000,19.359999,21.720005,22.680000,21.879995,10.370000,2.42,0.02,-3.280002,-5.17
1653004800000000000,20.0,230.25,218.129990,207.739990,199.839996,202.019989,209.389999,217.440002,229.080002,239.479996,253.039993,264.149994,...,13.090000,19.580000,22.160004,23.789999,21.439995,10.130000,2.67,0.37,-3.070002,-5.27
1653004800000000000,20.0,230.50,218.220001,207.629990,199.819992,201.959991,209.259995,217.289993,229.059998,239.519989,253.279999,264.119995,...,12.840000,20.020000,22.320004,24.529999,20.999994,9.849999,2.82,0.31,-3.020002,-5.41
1653004800000000000,20.0,230.75,218.199997,207.529999,199.879990,201.899994,209.069992,217.139999,228.929993,239.619995,253.509995,264.089996,...,12.599999,20.459999,22.310005,25.070000,20.659994,9.429999,2.85,-0.11,-3.030002,-5.55
1653004800000000000,20.0,231.00,218.059998,207.440002,199.879990,201.869995,208.889999,217.259995,228.860001,239.669998,253.569992,264.209991,...,12.360000,20.740000,22.300005,25.519999,20.219995,8.920000,3.03,-0.62,-3.040002,-5.49
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1653026400000000000,55.0,299.00,224.259995,221.699997,221.879990,221.019989,226.190002,226.529999,224.220001,223.690002,236.720001,248.189987,...,3.380000,4.000000,5.960000,5.550000,2.030000,-3.230000,-1.01,-2.29,-2.350000,1.21
1653026400000000000,55.0,299.25,224.299988,221.720001,221.750000,221.099991,226.000000,226.589996,224.129990,223.940002,236.739990,248.139999,...,3.500000,4.050000,6.030000,5.650000,4.220000,-2.010000,0.11,-1.85,-2.890000,2.13
1653026400000000000,55.0,299.50,224.349991,221.720001,221.639999,221.139999,225.839996,226.549988,224.049988,224.139999,236.759995,248.080002,...,3.590000,4.130000,6.070000,5.750000,6.640000,-0.350000,1.13,-1.87,-3.650000,3.11
1653026400000000000,55.0,299.75,224.449997,221.709991,221.559998,221.220001,225.669998,226.549988,223.899994,224.159988,236.860001,248.110001,...,3.640000,4.220000,6.090000,5.930000,8.990000,1.460000,1.52,-1.97,-4.550000,3.10


In [6]:
# def itertime(
#     source: pd.Index,
#     target: np.ndarray,
# ) -> Iterable[tuple[pd.Timestamp, pd.Index]]:
#     time_interval = len(target)
#     start_time = np.argmin(
#         abs(target[:, np.newaxis] - source.values) > pd.to_timedelta(time_interval, unit="h"),
#         axis=1,
#     )

#     end_time = np.roll(start_time, -1)
#     end_time[-1] = -1

#     for timestamp, tuple_slice in zip(target, zip(start_time, end_time)):
#         yield timestamp, source[slice(*tuple_slice)]


def shape_like(ps: pd.DataFrame, target: pd.DataFrame) -> pd.DataFrame:
    ps = ps.copy()
    ps[["WEST", "EAST"]] = index_abs_argmin(ps.bounds[["minx", "maxx"]] % 360, target.geo.lon)
    # and S,N
    ps[["SOUTH", "NORTH"]] = index_abs_argmin(ps.bounds[["miny", "maxy"]], target.geo.lat)
    return ps.drop(["geometry"], axis=1).set_index(["WEST", "EAST", "NORTH", "SOUTH"], append=True)


ps = shape_like(to_dataframe(probsevere), GALWEM)
ps
# if __name__ == "__main__":
#     ALL_GALWEM_FILES = sorted(glob(os.path.join("data", "galwem", "*.GR2")))
#     ALL_PROBSEVERE_FILES = sorted(glob(os.path.join("data", "probsevere", "*.json")))
#     glwm, probsevere = unpack_files(ALL_GALWEM_FILES, ALL_PROBSEVERE_FILES)

#     ps = shape_like(to_dataframe(probsevere), GALWEM)

# ps

# GALWEM = ds.unstack("hPa").reorder_levels(["validTime", "lat", "lon"]).reorder_levels(["hPa", "elements"], axis=1)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,MUCAPE,MLCAPE,MLCIN,EBSHEAR,SRH01KM,MEANWIND_1-3kmAGL,MESH,VIL_DENSITY,FLASH_RATE,FLASH_DENSITY,...,MAXRC_ICECF,WETBULB_0C_HGT,PWAT,CAPE_M10M30,LJA,SIZE,AVG_BEAM_HGT,MOTION_EAST,MOTION_SOUTH,PS
validTime,ID,WEST,EAST,NORTH,SOUTH,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1
1652999439000000000,25310,269.75,270.00,38.75,38.25,4587,2950,-13,38.6,86,32.1,0.52,1.58,71,0.76,...,0.0,11.9,2.0,816,1.1,842,3.277778,9.566,-9.504,75
1652999439000000000,25384,274.25,274.50,38.00,38.00,3571,3490,-24,42.8,194,19.5,0.48,1.57,28,0.66,...,0.0,11.5,1.8,859,0.8,270,3.242424,3.77,-1.666,48
1652999439000000000,25496,269.50,269.50,45.00,45.00,182,10,0,48.0,428,43.3,0.66,2.24,13,0.46,...,1.0,9.3,1.3,14,0.6,226,3.280851,16.893,-3.447,28
1652999439000000000,25504,273.25,273.50,39.00,39.00,2791,1740,-85,39.1,110,18.6,1.09,2.64,35,0.84,...,0.0,10.4,1.7,652,1.0,328,3.293478,13.776,0.591,89
1652999439000000000,25505,272.75,273.00,38.75,38.75,2954,2227,-51,40.0,172,17.2,1.65,2.91,39,0.92,...,0.0,11.5,1.7,843,2.0,241,3.270000,13.621,-1.49,98
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1653031715000000000,30546,282.75,282.75,31.25,31.25,3506,1202,-110,25.1,63,9.7,0.40,0.51,0,0.00,...,0.0,12.4,1.2,815,0.0,72,3.278947,10.478,10.213,5
1653031715000000000,30547,279.75,280.00,26.75,26.50,3080,2349,-1,25.2,25,7.7,0.13,0.75,0,0.00,...,0.0,12.4,2.0,683,0.0,48,3.268750,4.937,-4.723,2
1653031715000000000,30548,280.75,280.75,26.25,26.25,3530,2699,-999,30.8,34,7.9,0.00,0.66,0,0.03,...,0.0,12.4,1.8,705,0.0,45,3.283898,10.459,10.169,2
1653031715000000000,30549,281.00,281.00,26.25,26.25,3497,2694,-999,31.6,27,8.3,0.00,0.40,0,0.00,...,0.0,12.4,1.8,691,0.0,62,3.279070,10.478,10.213,3


In [7]:
target = GALWEM.copy()  # .index.unique("validTime")
source = ps.copy().droplevel("ID")  # .droplevel("ID")  # .index.unique("validTime")


def align_time(probsevere: pd.DataFrame, forecast: pd.DataFrame) -> pd.DataFrame:
    """aligns the time in the probsevere forecast"""

    def sync_time(stack: NDArray[np.int64], source_time: pd.Index) -> NDArray[np.int64]:
        # stack = target_time.to_numpy()[:, np.newaxis]
        delta = abs(stack - source_time.to_numpy()).astype("timedelta64[ns]")
        condition = delta < pd.to_timedelta(3, unit="h")
        value = np.where(condition, source_time, np.nan).astype(np.int64)
        return np.nanmax(value, axis=1)

    df = probsevere.reset_index("validTime")

    df["validTime"] = sync_time(
        df["validTime"].to_numpy()[:, np.newaxis],
        forecast.index.unique("validTime"),
    )

    return df.set_index("validTime", append=True).reorder_levels(["validTime", "WEST", "EAST", "NORTH", "SOUTH"])


df = align_time(
    ps.droplevel("ID"),
    GALWEM,
)
df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,MUCAPE,MLCAPE,MLCIN,EBSHEAR,SRH01KM,MEANWIND_1-3kmAGL,MESH,VIL_DENSITY,FLASH_RATE,FLASH_DENSITY,...,MAXRC_ICECF,WETBULB_0C_HGT,PWAT,CAPE_M10M30,LJA,SIZE,AVG_BEAM_HGT,MOTION_EAST,MOTION_SOUTH,PS
validTime,WEST,EAST,NORTH,SOUTH,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
1653004800000000000,269.75,270.00,38.75,38.25,4587,2950,-13,38.6,86,32.1,0.52,1.58,71,0.76,...,0.0,11.9,2.0,816,1.1,842,3.277778,9.566,-9.504,75
1653004800000000000,274.25,274.50,38.00,38.00,3571,3490,-24,42.8,194,19.5,0.48,1.57,28,0.66,...,0.0,11.5,1.8,859,0.8,270,3.242424,3.77,-1.666,48
1653004800000000000,269.50,269.50,45.00,45.00,182,10,0,48.0,428,43.3,0.66,2.24,13,0.46,...,1.0,9.3,1.3,14,0.6,226,3.280851,16.893,-3.447,28
1653004800000000000,273.25,273.50,39.00,39.00,2791,1740,-85,39.1,110,18.6,1.09,2.64,35,0.84,...,0.0,10.4,1.7,652,1.0,328,3.293478,13.776,0.591,89
1653004800000000000,272.75,273.00,38.75,38.75,2954,2227,-51,40.0,172,17.2,1.65,2.91,39,0.92,...,0.0,11.5,1.7,843,2.0,241,3.270000,13.621,-1.49,98
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1653026400000000000,282.75,282.75,31.25,31.25,3506,1202,-110,25.1,63,9.7,0.40,0.51,0,0.00,...,0.0,12.4,1.2,815,0.0,72,3.278947,10.478,10.213,5
1653026400000000000,279.75,280.00,26.75,26.50,3080,2349,-1,25.2,25,7.7,0.13,0.75,0,0.00,...,0.0,12.4,2.0,683,0.0,48,3.268750,4.937,-4.723,2
1653026400000000000,280.75,280.75,26.25,26.25,3530,2699,-999,30.8,34,7.9,0.00,0.66,0,0.03,...,0.0,12.4,1.8,705,0.0,45,3.283898,10.459,10.169,2
1653026400000000000,281.00,281.00,26.25,26.25,3497,2694,-999,31.6,27,8.3,0.00,0.40,0,0.00,...,0.0,12.4,1.8,691,0.0,62,3.279070,10.478,10.213,3


In [18]:
df2 = df.copy()


def probsevere_meshgrid(df: pd.DataFrame) -> pd.DataFrame:

    west, east, north, south = (df.index.get_level_values(name) for name in ["WEST", "EAST", "NORTH", "SOUTH"])

    condition = (north == south) & (west == east)
    # mesh_grid_condition
    df.loc[condition] = df.loc[condition].reset_index().groupby(["validTime", "WEST", "EAST", "NORTH", "SOUTH"]).mean()
    return df.droplevel(["WEST", "SOUTH"]).rename_axis(["validTime", "lat", "lon"])


df3 = probsevere_meshgrid(df2.astype(np.float32))
df4 = probsevere_meshgrid(df2.astype(np.float64))
(df3 == df4).all()
df3.AVG_BEAM_HGT, df4.AVG_BEAM_HGT

(validTime            lat     lon  
 1653004800000000000  270.00  38.75    3.277778
                      274.50  38.00    3.242424
                      269.50  45.00    3.278357
                      273.50  39.00    3.293478
                      273.00  38.75    3.270000
                                         ...   
 1653026400000000000  282.75  31.25    3.280157
                      280.00  26.75    3.268750
                      280.75  26.25    3.281434
                      281.00  26.25    3.279070
                      278.25  26.00    3.277778
 Name: AVG_BEAM_HGT, Length: 27070, dtype: float32,
 validTime            lat     lon  
 1653004800000000000  270.00  38.75    3.277778
                      274.50  38.00    3.242424
                      269.50  45.00    3.278357
                      273.50  39.00    3.293478
                      273.00  38.75    3.270000
                                         ...   
 1653026400000000000  282.75  31.25    3.280157
            

In [9]:
target = GALWEM.copy()  # .index.unique("validTime")
source = ps.copy().droplevel("ID")  # .index.unique("validTime")
interval = 3
# print(source)


def doit(ps_time_index, time_index):
    ...
    # print(time_slice, ps_time_index)
    frame = source.loc[time_index, :].groupby(["WEST", "EAST", "NORTH", "SOUTH"]).mean()
    # print(frame)

    index_names = frame.index.names

    west, east, north, south = (frame.index.get_level_values(name) for name in ["WEST", "EAST", "NORTH", "SOUTH"])

    mesh_grid_condition = (north == south) & (west == east)
    # NOTE NEEDS VALIDATION
    frame.loc[mesh_grid_condition] = frame.loc[mesh_grid_condition].groupby(["WEST", "EAST", "NORTH", "SOUTH"]).mean()
    # print(frame)
    frame = frame.droplevel(["WEST", "SOUTH"])

    frame.index = frame.index.set_names(["lon", "lat"])

    # print(type(ps_time_index))
    frame["validTime"] = ps_time_index

    return frame.reset_index().set_index(["validTime", "lat", "lon"], append=True)


def itertime3(targetT, sourceT):
    # target = GALWEM.index.unique("validTime")
    # source = ps.index.unique("validTime")
    interval = 3
    start_time = np.argmax(
        np.abs(targetT.values[:, np.newaxis] - sourceT.values).astype("timedelta64[ns]")
        < pd.to_timedelta(interval, unit="h"),
        axis=1,
    )
    end_time = np.roll(start_time, -1)
    end_time[-1] = -1
    print(len())
    return zip(targetT, zip(start_time, end_time))


FINAL = pd.concat(
    (
        doit(x, source.index[slice(*y)])
        for x, y in itertime3(GALWEM.index.unique("validTime"), ps.index.unique("validTime"))
    ),
    axis=0,
)

TypeError: len() takes exactly one argument (0 given)