In [None]:
%reload_ext gswp.jupyter
from pathlib import Path

import numpy as np
import pandas as pd
import xarray as xr

from sklearn.naive_bayes import GaussianNB
import tensorflow as tf

from gswp.constants import STORE, PROBSEVERE as PS, GMGSI


# Report the device details of every GPU TensorFlow can see (prints [] when none).
print(
    list(
        map(
            tf.config.experimental.get_device_details,
            tf.config.list_physical_devices("GPU"),
        )
    )
)

In [None]:
# Storm-object table; later cells read its minx/maxx/miny/maxy, ID, PS and SIZE
# columns, so PS.load() is expected to already provide those names.
features = PS.load()
features

In [None]:
# Raw (unclipped) GMGSI grid; the bounding-box-filtered version below rebinds `gmgsi`.
(gmgsi := GMGSI.load())

In [None]:
from typing import Generic, NewType, TypeVar
from gswp.constants import MRMS_BOUNDS


# Unpack the MRMS bounding box into west/east/south/north edges; load_and_filter
# keeps only lon strictly inside (W, E) and lat strictly inside (S, N).
W, E, S, N = MRMS_BOUNDS
_1D = NewType("1d", tuple)
_2D = NewType("2d", tuple)


def load_and_filter(dataset: "xr.Dataset | None" = None) -> xr.Dataset:
    """Return the GMGSI grid restricted to the MRMS bounding box.

    Parameters
    ----------
    dataset:
        An already-loaded GMGSI dataset to subset. When omitted,
        ``GMGSI.load()`` is called, which is the original behaviour.

    Returns
    -------
    The result of ``.sel`` on the dataset, with only the ``lat`` values
    strictly inside ``(S, N)`` and the ``lon`` values strictly inside ``(W, E)``.
    This is an xarray object, not a ``pd.DataFrame`` as previously annotated.
    """
    if dataset is None:
        dataset = GMGSI.load()
    lat = dataset["lat"].to_numpy()
    lon = dataset["lon"].to_numpy()
    # Select by the surviving coordinate labels rather than a slice, so the
    # result does not depend on whether lat/lon are stored ascending or not.
    return dataset.sel(
        {"lat": lat[(lat > S) & (lat < N)], "lon": lon[(lon > W) & (lon < E)]}
    )


def min_diff(target: np.ndarray, values: np.ndarray) -> np.ndarray:
    """Map every element of ``values`` to the index of its nearest ``target`` element.

    Parameters
    ----------
    target:
        1-D array-like of candidate values (e.g. grid coordinates).
    values:
        1-D array-like of query values.

    Returns
    -------
    Signed integer (``np.intp``) array with one entry per element of ``values``.
    Each entry indexes into ``target``. On ties, the lowest target index wins.

    Raises
    ------
    ValueError
        If ``target`` is empty, because there is no nearest element.
    """
    target = np.asarray(target)
    values = np.asarray(values)
    if target.size == 0:
        raise ValueError("min_diff: `target` must contain at least one value")
    # (len(target), len(values)) matrix of absolute distances; argmin over the
    # target axis returns the first nearest index for each query value.
    distance = np.abs(target[:, np.newaxis] - values)
    return np.argmin(distance, axis=0)


gmgsi = load_and_filter()
# Sorted, de-duplicated coordinate values that min_diff snaps storm edges onto.
# NOTE(review): np.unique sorts ascending; if the dataset stores lat descending,
# indices into `lat` will not match the data's lat axis order — confirm.
lat = np.unique(gmgsi["lat"])
lon = np.unique(gmgsi["lon"])
gmgsi

In [None]:
# Snap each storm bounding-box edge to the nearest grid index on its axis.
min_lon, max_lon = (min_diff(lon, features[edge].to_numpy()) for edge in ("minx", "maxx"))
min_lat, max_lat = (min_diff(lat, features[edge].to_numpy()) for edge in ("miny", "maxy"))

In [None]:
# NOTE: NOT CORRECT!!! The times are not aligned.
# Append lat/lon to the index, then unstack lon so it becomes its own axis.
data = (
    gmgsi
    .set_index(["lat", "lon"], append=True)
    .unstack("lon")
)


def extract(arr, lat_bounds=None, lon_bounds=None):
    """Cut one sub-array per storm out of ``arr``.

    Parameters
    ----------
    arr:
        Array whose last two axes are indexed by the lat and lon grid indices;
        any leading axes (e.g. channel) are kept whole.
    lat_bounds, lon_bounds:
        ``(start, stop)`` pairs of equal-length integer index arrays, one entry
        per storm. They default to the notebook globals ``(min_lat, max_lat)``
        and ``(min_lon, max_lon)``.

    Returns
    -------
    list of ``arr[..., lat0:lat1 + 1, lon0:lon1 + 1]`` views, one per storm.
    """
    if lat_bounds is None:
        lat_bounds = (min_lat, max_lat)
    if lon_bounds is None:
        lon_bounds = (min_lon, max_lon)
    lat_lo, lat_hi = lat_bounds
    lon_lo, lon_hi = lon_bounds
    # Both bounds are nearest-grid indices of the box edges, so the upper edge
    # belongs to the box and the slice stop must be hi + 1. An exclusive stop
    # drops that row/column and yields an empty cutout whenever both edges snap
    # to the same pixel.
    # NOTE(review): this assumes the indices are ascending along arr's axes;
    # with a descending grid lo > hi and the slice is empty — confirm.
    return [
        arr[..., lat0 : lat1 + 1, lon0 : lon1 + 1]
        for lat0, lat1, lon0, lon1 in zip(lat_lo, lat_hi, lon_lo, lon_hi)
    ]


# Stack the GMGSI channels along a new leading axis, in the order 0=LW, 1=SW, 2=WV.
arr = np.array([data[channel].to_numpy() for channel in ("GMGSI_LW", "GMGSI_SW", "GMGSI_WV")])

# One per-storm cutout (all three channels) per row of the storm table.
features["observations"] = extract(arr)

features

In [None]:
# Shape check of the stacked channel array. `data` was already unstacked on
# "lon" when it was built, so lon is no longer an index level and a second
# `.unstack("lon")` raises; read each channel as-is instead.
np.array(
    [data[channel].to_numpy() for channel in ("GMGSI_LW", "GMGSI_SW", "GMGSI_WV")]
).shape

In [None]:
import matplotlib.pyplot as plt

components = [
    "CAPE_M10M30",
    "EBSHEAR",
    "FLASH_DENSITY",
    "FLASH_RATE",
    "LJA",
    "MAXLLAZ",
    "MEANWIND_1-3kmAGL",
    "MESH",
    "MLCAPE",
    "MLCIN",
    "MUCAPE",
    "P98LLAZ",
    "P98MLAZ",
    "PS",
    "PWAT",
    # "SIZE",
    "SRH01KM",
    "VIL_DENSITY",
    "WETBULB_0C_HGT",
]

# Gallery selection: storms with PS above the threshold, largest SIZE first.
PS_THRESHOLD = 60
TOP_N = 20
# Titles in the channel order used when `arr` was stacked (LW, SW, WV).
CHANNEL_TITLES = ("longwave", "shortwave", "water vapor")

largest_storms = (
    features[features["PS"] > PS_THRESHOLD]
    .sort_values("SIZE", ascending=False)
    .head(TOP_N)
)

for _, s in largest_storms.iterrows():
    extent = s[["minx", "maxx", "miny", "maxy"]].to_list()
    storm_id = s["ID"]
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize=(18, 10))
    # No cell creates per-channel "longwave"/"shortwave"/"watervapor" columns;
    # the cutouts live in "observations", with the channels on the first axis.
    # NOTE(review): if a cutout still carries a time axis, each image is 3-D and
    # imshow will reject it — see the "times are not aligned" note.
    for ax, title, image in zip((ax1, ax2, ax3), CHANNEL_TITLES, s["observations"]):
        ax.set_title(f"{storm_id=} {title}")
        # NOTE(review): imshow defaults to origin="upper"; if cutout rows run
        # south->north (ascending lat), the image is vertically flipped — confirm.
        ax.imshow(image, extent=extent)
    s[components].plot.bar(ax=ax4)
    ax4.set_title(f"{storm_id=} components")
    fig.tight_layout()
    # Render and release each figure instead of accumulating TOP_N open figures.
    plt.show()