In [1]:
# Load the `gswp.jupyter` IPython extension, reloading it if it is already loaded.
%reload_ext gswp.jupyter

In [2]:
import numpy as np
import pandas as pd
from gswp.api import extract_features

# Shorthand for MultiIndex slicing with `.loc`, e.g. `df.loc[idx[:, lo:hi], :]`.
idx = pd.IndexSlice

In [3]:
def make_gridspace(
    x1: float, y1: float, x2: float, y2: float, x_size: int = 972, y_size: int = 635
) -> pd.DataFrame:
    """Build a synthetic lat/lon grid filled with random two-channel pixel values.

    Longitudes are ``x_size`` evenly spaced float32 points from ``x1`` to ``x2``;
    latitudes are ``y_size`` evenly spaced float32 points from ``y1`` to ``y2``
    (both endpoints included). Every (lat, lon) pair of the Cartesian product
    gets exactly one row.

    Returns
    -------
    pd.DataFrame
        Indexed by a ``["lat", "lon"]`` MultiIndex sorted ascending, with integer
        columns ``"water_vapor"`` and ``"long_wave_ir"`` drawn from the global
        ``np.random`` state via ``randint(0, 255)``. The upper bound is
        exclusive, so values lie in 0..254.
    """
    x = np.linspace(x1, x2, x_size, dtype=np.float32)
    y = np.linspace(y1, y2, y_size, dtype=np.float32)

    # Rows are generated lon-major (all latitudes of the first longitude, then
    # the next longitude, ...). The random values are drawn in that order and
    # the frame is re-sorted by (lat, lon) only afterwards.
    index = pd.MultiIndex.from_product([x, y], names=["lon", "lat"])
    grid = pd.DataFrame(index=index)
    grid["water_vapor"] = np.random.randint(0, 255, size=len(grid))
    grid["long_wave_ir"] = np.random.randint(0, 255, size=len(grid))
    return grid.reorder_levels(["lat", "lon"]).sort_index(level=["lat", "lon"])


def make_features(
    x1: int, y1: int, x2: int, y2: int, n_features: int = 76_020
) -> pd.DataFrame:
    """Generate ``n_features`` random integer bounding boxes inside a region.

    The region is spanned by the corners ``(x1, y1)`` and ``(x2, y2)``, which
    may be given in any order. Each axis is sampled from ``[low, high)`` of its
    two bounds; ``np.random.randint`` excludes the upper bound. Values come
    from the global ``np.random`` state.

    Returns
    -------
    pd.DataFrame
        Columns ``minx``, ``maxx``, ``miny``, ``maxy`` with ``minx <= maxx``
        and ``miny <= maxy`` on every row.
    """
    x_lo, x_hi = sorted((x1, x2))
    y_lo, y_hi = sorted((y1, y2))
    # Two independent draws per axis; the element-wise min/max turns each pair
    # into an ordered (min, max) edge of the box.
    xa = np.random.randint(x_lo, x_hi, n_features)
    xb = np.random.randint(x_lo, x_hi, n_features)
    ya = np.random.randint(y_lo, y_hi, n_features)
    yb = np.random.randint(y_lo, y_hi, n_features)
    return pd.DataFrame(
        {
            "minx": np.minimum(xa, xb),
            "maxx": np.maximum(xa, xb),
            "miny": np.minimum(ya, yb),
            "maxy": np.maximum(ya, yb),
        }
    )


# Shape-annotated aliases for 1-D arrays. The second type parameter of
# ``np.ndarray`` is a *dtype* type, so the scalar type must be wrapped in
# ``np.dtype[...]`` for static type checkers to accept it.
floating1DArray = np.ndarray[tuple[int], np.dtype[np.floating]]
# Kept for backward compatibility. Index arrays produced by ``np.argmin`` are
# signed ``np.intp``, which ``intp1DArray`` describes accurately.
unsignedinteger1DArray = np.ndarray[tuple[int], np.dtype[np.unsignedinteger]]
intp1DArray = np.ndarray[tuple[int], np.dtype[np.intp]]


def min_diff(
    target: np.ndarray[tuple[int], np.dtype[np.floating]],
    values: np.ndarray[tuple[int], np.dtype[np.number]],
    chunk_size: int = 4096,
) -> np.ndarray[tuple[int], np.dtype[np.intp]]:
    """Return, for every entry of ``values``, the index of the closest ``target`` entry.

    Closeness is the absolute difference. On ties (a value exactly halfway
    between two targets, or duplicated targets), the lowest target index wins;
    this is ``np.argmin``'s first-occurrence rule. ``target`` does not need to
    be sorted.

    The naive broadcast ``|target[:, None] - values|`` materialises a
    ``len(target) x len(values)`` matrix, which is hundreds of MB for ~1000
    targets and ~76k values. Values are therefore processed ``chunk_size`` at
    a time. This bounds the temporary to ``len(target) * chunk_size`` elements
    without changing the result.

    Parameters
    ----------
    target : 1-D array
        Candidate points (e.g. a grid axis).
    values : 1-D array
        Points to snap onto ``target``; any real numeric dtype.
    chunk_size : int, default 4096
        Maximum number of ``values`` handled per broadcast.

    Returns
    -------
    numpy.ndarray of np.intp, shape ``(len(values),)``
        Indices into ``target``.

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive, or if ``target`` is empty while
        ``values`` is not.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")
    target = np.asarray(target)
    values = np.asarray(values)

    def _nearest(chunk):
        # (len(target), len(chunk)) distance matrix, reduced over the target axis.
        return np.argmin(np.abs(target[:, np.newaxis] - chunk), axis=0)

    # Small (or non-1-D) inputs: a single broadcast is cheap enough.
    if values.ndim != 1 or values.size <= chunk_size:
        return _nearest(values)
    return np.concatenate(
        [
            _nearest(values[start : start + chunk_size])
            for start in range(0, values.size, chunk_size)
        ]
    )


# --- Synthetic benchmark data ---------------------------------------------
# Seed the global NumPy RNG used by make_gridspace/make_features so every
# Restart & Run All produces the same grid and features.
SEED = 0
np.random.seed(SEED)

# Grid corners as (lon, lat).
xy_1 = -129, 54  # west, north
xy_2 = -60, 20  # east, south
x_size, y_size = 972, 635
X1, Y1 = xy_1
X2, Y2 = xy_2
gs = make_gridspace(X1, Y1, X2, Y2, x_size=x_size, y_size=y_size)
# Unique coordinate axes of the (lat, lon)-sorted grid.
lat, lon = (gs.index.unique(crd).to_numpy() for crd in ("lat", "lon"))
assert (len(lat), len(lon)) == (y_size, x_size), "unexpected grid axis sizes"

# FEATURES: random bounding boxes, snapped to the nearest grid indices.
n_features = 76_020
features = make_features(X1, Y1, X2, Y2, n_features=n_features)

min_lon = min_diff(lon, features.minx.to_numpy())
max_lon = min_diff(lon, features.maxx.to_numpy())
min_lat = min_diff(lat, features.miny.to_numpy())
max_lat = min_diff(lat, features.maxy.to_numpy())
# With ascending axes and min <= max edges, snapped start indices never exceed stops.
assert (min_lon <= max_lon).all() and (min_lat <= max_lat).all()


In [4]:
%%time
arr = gs.unstack("lon")["water_vapor"].to_numpy(dtype=np.float32)
result = [
    arr[x1:x2, y1:y2] for x1, x2, y1, y2 in np.c_[min_lat, max_lat, min_lon, max_lon]
]
result

CPU times: user 179 ms, sys: 0 ns, total: 179 ms
Wall time: 177 ms


[array([[169.,  36., 227., ..., 163.,  89., 197.],
        [197.,  28., 226., ...,  43., 180.,  98.],
        [ 43., 160., 208., ..., 105., 157., 150.],
        ...,
        [120., 135.,  26., ..., 139.,  98.,  56.],
        [ 23., 117., 183., ..., 212., 183.,  93.],
        [209., 247.,  95., ...,   7., 144.,  84.]], dtype=float32),
 array([[241., 218.,   4., ..., 184., 159.,  69.],
        [ 13., 138., 240., ..., 222., 137., 217.],
        [ 97.,   1., 138., ..., 224.,  26., 254.],
        ...,
        [  8., 125., 174., ...,  23.,  90., 230.],
        [233., 201., 187., ..., 123.,  54.,  54.],
        [149.,   3.,  43., ...,  94.,   5., 109.]], dtype=float32),
 array([[104., 219.,  92., ...,   6., 189., 114.],
        [ 67., 208., 162., ..., 161., 207.,  49.],
        [ 52.,  71., 202., ..., 129.,  42., 134.],
        ...,
        [ 79.,   2.,  21., ..., 130., 147.,  86.],
        [161., 121., 112., ..., 115., 228.,   0.],
        [217., 162.,  31., ...,   0.,  69., 246.]], dtype=fl

In [5]:
%%time

extract_features(
    gs.unstack("lon")["water_vapor"].to_numpy(dtype=np.float32),
    min_lon,
    max_lon,
    min_lat,
    max_lat,
)
# extract_features(
#     gs.unstack("lon")["long_wave_ir"].to_numpy(dtype=np.float32),
#     min_lon,
#     max_lon,
#     min_lat,
#     max_lat,
# )

CPU times: user 170 ms, sys: 9.77 ms, total: 180 ms
Wall time: 183 ms


[array([[169.,  36., 227., ..., 163.,  89., 197.],
        [197.,  28., 226., ...,  43., 180.,  98.],
        [ 43., 160., 208., ..., 105., 157., 150.],
        ...,
        [120., 135.,  26., ..., 139.,  98.,  56.],
        [ 23., 117., 183., ..., 212., 183.,  93.],
        [209., 247.,  95., ...,   7., 144.,  84.]], dtype=float32),
 array([[241., 218.,   4., ..., 184., 159.,  69.],
        [ 13., 138., 240., ..., 222., 137., 217.],
        [ 97.,   1., 138., ..., 224.,  26., 254.],
        ...,
        [  8., 125., 174., ...,  23.,  90., 230.],
        [233., 201., 187., ..., 123.,  54.,  54.],
        [149.,   3.,  43., ...,  94.,   5., 109.]], dtype=float32),
 array([[104., 219.,  92., ...,   6., 189., 114.],
        [ 67., 208., 162., ..., 161., 207.,  49.],
        [ 52.,  71., 202., ..., 129.,  42., 134.],
        ...,
        [ 79.,   2.,  21., ..., 130., 147.,  86.],
        [161., 121., 112., ..., 115., 228.,   0.],
        [217., 162.,  31., ...,   0.,  69., 246.]], dtype=fl