In [None]:
%load_ext autoreload
%autoreload 2
%load_ext jupyter_black

In [128]:
from __future__ import annotations

try:
    __file__ # type: ignore
except NameError:
    __file__ = __vsc_ipynb_file__  # type: ignore
import os

os.environ["PATH_TO_SEVIR"] = "/mnt/nuc/c/sevir"
import logging
import enum
import typing
from typing_extensions import Self, Any

import xarray as xr
import pandas as pd
import numpy as np
import numpy.typing as npt


try:
    from tqdm import tqdm as RW
except ImportError:
    logging.info("You need to install tqdm to use progress bar")
    RW = list


from src import (
    ID,
    FILE_NAME,
    FILE_INDEX,
    IMG_TYPE,
    TIME_UTC,
    MINUTE_OFFSETS,
    EPISODE_ID,
    EVENT_ID,
    EVENT_TYPE,
    LL_LAT,
    LL_LON,
    UR_LAT,
    UR_LON,
    PROJ,
    SIZE_X,
    SIZE_Y,
    HEIGHT_M,
    WIDTH_M,
    DATA_MIN,
    DATA_MAX,
    PCT_MISSING,
    SEVIR_DTYPES,
    VISIBLE,
    IR_069,
    IR_107,
    VERTICALLY_INTEGRATED_LIQUID,
    LIGHTNING,
    DEFAULT_N_FRAMES,
    DEFAULT_FRAME_TIMES,
    PATH_TO_SEVIR,
    DEFAULT_CATALOG,
    DEFAULT_DATA_HOME,
    DEFAULT_N_FRAMES,
    CATALOG_DTYPES,
    SEVIRImageType,
    SEVIRCatalog,
)

assert all(os.path.exists(p) for p in [DEFAULT_CATALOG, DEFAULT_DATA_HOME])

In [129]:
EVENT_INDEX = "event_index"
H5_FILE = "h5_file"
from typing import Callable
import h5py
import multiprocessing.pool

if typing.TYPE_CHECKING:
    from pandas._typing import Scalar, IndexType, MaskType, HashableT
    from pandas.core.indexing import _IndexSliceTuple as IndexSliceTuple
else:
    HashableT = typing.TypeVar("HashableT", bound=typing.Hashable)
    Scalar = Any
    IndexType = Any
    MaskType = Any
    IndexSliceTuple = Any

In [131]:
LocIndexer: typing.TypeAlias = """(
    int
    | IndexType 
    | MaskType 
    | Callable[[pd.DataFrame], IndexType | MaskType | list[HashableT]] 
    | list[HashableT] 
    | tuple[IndexType | MaskType | list[HashableT] | slice | IndexSliceTuple | Callable, list[HashableT] | slice | pd.Series[bool] | Callable]
)"""


def html(obj) -> str:
    if not hasattr(obj, "_repr_html_"):
        return repr(obj)
    return obj._repr_html_()


class SEVIRBase:
    __slots__ = ("_values", "_index")
    if typing.TYPE_CHECKING:
        _values: pd.DataFrame
        _index: pd.MultiIndex

    def __init__(self, values: pd.DataFrame) -> None:
        self._values = values
        self._index = typing.cast(pd.MultiIndex, values.index)

    @property
    def values(self) -> pd.DataFrame:
        return self._values.copy()

    @property
    def index(self) -> pd.MultiIndex:
        return self._index.copy()

    @property
    def loc(self):
        return SEVIRBase(self._values)

    def __getitem__(self, idx: LocIndexer) -> pd.DataFrame:
        """I generic getitem method that returns the pandas DataFrame"""
        # because the index is a multi index the type checker does not understand that a single
        # integer will still return a pandas DataFrame and not a Series
        return self._values.loc.__getitem__(idx)  # type: ignore

    def __repr__(self) -> str:
        return repr(self._values)

    def _repr_html_(self) -> str:
        return html(self._values)


class H5FileSeries(typing.Mapping):
    def __init__(self, file_names: pd.Series[str]) -> None:
        self._names = file_names
        self.index = index = file_names.index
        files = {}
        ds = xr.Dataset()

        with multiprocessing.pool.ThreadPool() as pool:
            for (idx, stype), file in zip(index, pool.imap(h5py.File, file_names)):
                files[(idx, stype)] = file
                # h5ds[(idx, stype)] = file[stype]

        self._files = pd.Series(files)
        self._h5ds = ds
        # self._files: pd.Series[h5py.File] = file_names.map(h5py.File)  # type: ignore
        # self._h5ds
        self._is_open = True

    def __getitem__(self, key) -> h5py.File | typing.Iterable[h5py.File]:  # type: ignore
        return self._h5ds[key]

    def __iter__(self):
        return iter(self._h5ds)

    def __len__(self):
        return len(self._h5ds)

    def close(self) -> None:
        for idx in self.index:
            self._files.pop(idx).close()
        self._is_open = False


class SEVIRCatalog(SEVIRBase):
    def __init__(
        self,
        catalog: str | pd.DataFrame,
        *,
        image_types: set[SEVIRImageType],
        shuffle: int | None = None,
    ) -> None:
        if isinstance(catalog, pd.DataFrame):
            df = catalog
        else:
            df = self._read_catalog(catalog, image_types)
        if shuffle:
            df = df.sample(frac=1)
        super().__init__(df.sort_index(axis=0, level=EVENT_INDEX))

    def _read_catalog(
        self, catalog: str, image_types: set[SEVIRImageType]
    ) -> pd.DataFrame:
        df = (
            pd.read_csv(
                catalog,
                parse_dates=[TIME_UTC],
                low_memory=False,
                dtype=CATALOG_DTYPES,
            )
            .drop(columns=[PROJ])
            .drop_duplicates()
        )
        # remove all rows that don't have the selected image types
        df = df.loc[df[IMG_TYPE].isin(image_types)]
        # the ID columns is a string with either a "S" or "R" prefix.
        df[ID] = df[ID].str.slice(1).astype(int)
        # set the index to the ID column
        df.set_index([ID], inplace=True)
        # group all of the files by their ID, and remove any where there are not complete set of image types
        mask = df.groupby(ID)[IMG_TYPE].size() == len(set(image_types))
        # mask out the index to only include the IDs that have all of the image types
        df = df.loc[mask[mask].index, :]
        ids = df.index.get_level_values(ID).to_frame(index=False)
        df.index = pd.Index(
            ids.assign(event_index=ids.groupby(ID).ngroup()).event_index,
            name=EVENT_INDEX,
        )

        return df.set_index(IMG_TYPE, append=True)

    @property
    def files(self) -> pd.Series[str]:
        return self._values[FILE_NAME]

    def prefix(self, prefix: str, /) -> SEVIRCatalog:
        self._values[FILE_NAME] = [
            os.path.join(prefix, p) for p in self._values[FILE_NAME]
        ]
        return self

    def validate_paths(self) -> SEVIRCatalog:
        for file in self.get_paths():
            if not os.path.exists(file):
                raise FileNotFoundError(file)
        return self

    def get_paths(self) -> pd.Series[str]:
        return self._values[FILE_NAME]

    def split(
        self,
        x: list[SEVIRImageType],
        y: list[SEVIRImageType],
    ) -> tuple[SEVIRCatalog, SEVIRCatalog]:
        x_cat = SEVIRCatalog(self.loc[pd.IndexSlice[:, x], :], image_types=set(x))
        y_cat = SEVIRCatalog(self.loc[pd.IndexSlice[:, y], :], image_types=set(y))
        return x_cat, y_cat


class SEVIRLoader:
    inputs: SEVIRCatalog
    features: SEVIRCatalog

    def __init__(
        self,
        inputs: list[SEVIRImageType],
        features: list[SEVIRImageType],
        *,
        path_to_sevir: str | None = None,
        catalog: str = "CATALOG.csv",
        shuffle: int = 0,
        validate_paths: bool = False,
        batch_size: int = 32,
    ) -> None:
        cat = SEVIRCatalog(catalog, image_types=set(inputs + features), shuffle=shuffle)

        if path_to_sevir is not None:
            cat.prefix(path_to_sevir)
        if validate_paths:
            cat.validate_paths()

        self.index = cat.index.get_level_values(EVENT_INDEX).unique()
        self.inputs, self.features = cat.split(inputs, features)
        self.batch_size = batch_size

        # self._h5fs = (
        #     self.inputs.files.pipe(H5FileSeries),
        #     self.features.files.pipe(H5FileSeries),
        # )

    @property
    def h5fs(self) -> pd.Series[h5py.File, h5.Dataset]:  # type: ignore
        return self._h5fs  # type: ignore

    # def read(self) -> SEVIRCatalog:
    #     self._h5fm = self._values[FILE_NAME].map(h5py.File).to_dict()
    #     return self

    # def close(self) -> None:
    #     for fm in self._h5fs:
    #         fm.close()

    def get_batch(self, idx: LocIndexer) -> tuple[H5FileSeries, H5FileSeries]:
        inputs, features = self.h5fs

        return inputs[idx], features[idx]
        # inputs = self.inputs.loc[idx]
        # features = self.features.loc[idx]
        # return inputs[FILE_NAME], features[FILE_NAME]

    def __getitem__(
        self, idx: int | list[int] | slice
    ) -> tuple[H5FileSeries, H5FileSeries]:
        return self.get_batch(idx)

    def __len__(self) -> int:
        return len(self.index)

    def _repr_html_(self):
        return f"""\
<h3>inputs</h3>
{html(self.inputs)}
<h3>features</h3>
{html(self.features)}"""

    def __repr__(self):
        return f"""\
[inputs]
{repr(self.inputs)}
[features]
{repr(self.features)}"""

    def __iter__(
        self,
    ) -> typing.Generator[tuple[pd.Series, pd.Series], None, None]:
        yield from (self.get_batch(idx) for idx in self.index)


loader = SEVIRLoader(
    [VISIBLE],
    [VERTICALLY_INTEGRATED_LIQUID],
    path_to_sevir=DEFAULT_DATA_HOME,
    catalog=DEFAULT_CATALOG,
    shuffle=0,
    validate_paths=True,
)

loader

Unnamed: 0_level_0,Unnamed: 1_level_0,file_name,file_index,time_utc,minute_offsets,episode_id,event_id,event_type,llcrnrlat,llcrnrlon,urcrnrlat,urcrnrlon,size_x,size_y,height_m,width_m,data_min,data_max,pct_missing
event_index,img_type,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
0,vis,/mnt/nuc/c/sevir/data/vis/2018/SEVIR_VIS_STORM...,0,2018-01-21 23:14:00,-119:-114:-109:-104:-99:-94:-89:-84:-79:-74:-6...,121603.0,727819,Hail,31.174610,-99.007301,34.593082,-94.854282,768,768,384000.0,384000.0,-0.003144,0.936059,0.0
1,vis,/mnt/nuc/c/sevir/data/vis/2018/SEVIR_VIS_STORM...,3,2018-01-22 03:26:00,-121:-116:-111:-106:-101:-96:-91:-86:-81:-76:-...,121701.0,728503,Tornado,31.748748,-96.753871,35.079995,-92.480682,768,768,384000.0,384000.0,-0.003644,0.020730,0.0
2,vis,/mnt/nuc/c/sevir/data/vis/2018/SEVIR_VIS_STORM...,4,2018-02-06 20:06:00,-121:-116:-111:-106:-101:-96:-91:-86:-81:-76:-...,121968.0,730231,Hail,30.814419,-97.017593,34.158585,-92.804324,768,768,384000.0,384000.0,0.070697,1.169730,0.0
3,vis,/mnt/nuc/c/sevir/data/vis/2018/SEVIR_VIS_STORM...,20,2018-02-10 22:55:00,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,122007.0,730443,Flood,36.146405,-84.414905,38.931600,-79.347818,768,768,384000.0,384000.0,-0.003386,0.555044,0.0
4,vis,/mnt/nuc/c/sevir/data/vis/2018/SEVIR_VIS_STORM...,11,2018-02-11 19:13:00,-118:-113:-108:-103:-98:-93:-88:-83:-78:-73:-6...,122033.0,730587,Flood,37.855918,-78.668305,40.349141,-73.223167,768,768,384000.0,384000.0,0.027006,0.740143,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
12751,vis,/mnt/nuc/c/sevir/data/vis/2019/SEVIR_VIS_RANDO...,214,2019-11-30 18:44:00,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,,,,34.844928,-91.123460,37.936701,-86.429825,768,768,384000.0,384000.0,0.030694,1.156246,0.0
12752,vis,/mnt/nuc/c/sevir/data/vis/2019/SEVIR_VIS_RANDO...,213,2019-11-30 18:44:00,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,,,,36.192798,-85.468485,39.026754,-80.440684,768,768,384000.0,384000.0,0.039976,0.763576,0.0
12753,vis,/mnt/nuc/c/sevir/data/vis/2019/SEVIR_VIS_RANDO...,112,2019-11-30 18:44:00,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,,,,36.064382,-88.163621,39.021576,-83.257717,768,768,384000.0,384000.0,0.043733,1.094010,0.0
12754,vis,/mnt/nuc/c/sevir/data/vis/2019/SEVIR_VIS_RANDO...,108,2019-11-30 18:44:00,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,,,,40.973362,-73.972739,43.191927,-68.026592,768,768,384000.0,384000.0,0.006576,0.347847,0.0

Unnamed: 0_level_0,Unnamed: 1_level_0,file_name,file_index,time_utc,minute_offsets,episode_id,event_id,event_type,llcrnrlat,llcrnrlon,urcrnrlat,urcrnrlon,size_x,size_y,height_m,width_m,data_min,data_max,pct_missing
event_index,img_type,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
0,vil,/mnt/nuc/c/sevir/data/vil/2018/SEVIR_VIL_STORM...,141,2018-01-21 23:14:00,-119:-114:-109:-104:-99:-94:-89:-84:-79:-74:-6...,121603.0,727819,Hail,31.174610,-99.007301,34.593082,-94.854282,384,384,384000.0,384000.0,0.0,254.0,0.000000
1,vil,/mnt/nuc/c/sevir/data/vil/2018/SEVIR_VIL_STORM...,0,2018-01-22 03:26:00,-121:-116:-111:-106:-101:-96:-91:-86:-81:-76:-...,121701.0,728503,Tornado,31.748748,-96.753871,35.079995,-92.480682,384,384,384000.0,384000.0,0.0,254.0,0.000000
2,vil,/mnt/nuc/c/sevir/data/vil/2018/SEVIR_VIL_STORM...,600,2018-02-06 20:06:00,-121:-116:-111:-106:-101:-96:-91:-86:-81:-76:-...,121968.0,730231,Hail,30.814419,-97.017593,34.158585,-92.804324,384,384,384000.0,384000.0,0.0,254.0,0.061224
3,vil,/mnt/nuc/c/sevir/data/vil/2018/SEVIR_VIL_STORM...,205,2018-02-10 22:55:00,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,122007.0,730443,Flood,36.146405,-84.414905,38.931600,-79.347818,384,384,384000.0,384000.0,0.0,188.0,0.000000
4,vil,/mnt/nuc/c/sevir/data/vil/2018/SEVIR_VIL_STORM...,665,2018-02-11 19:13:00,-118:-113:-108:-103:-98:-93:-88:-83:-78:-73:-6...,122033.0,730587,Flood,37.855918,-78.668305,40.349141,-73.223167,384,384,384000.0,384000.0,0.0,241.0,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
12751,vil,/mnt/nuc/c/sevir/data/vil/2019/SEVIR_VIL_RANDO...,2213,2019-11-30 18:46:15,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,,,,34.844928,-91.123460,37.936701,-86.429825,384,384,384000.0,384000.0,0.0,254.0,0.000000
12752,vil,/mnt/nuc/c/sevir/data/vil/2019/SEVIR_VIL_RANDO...,2212,2019-11-30 18:46:15,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,,,,36.192798,-85.468485,39.026754,-80.440684,384,384,384000.0,384000.0,0.0,200.0,0.000000
12753,vil,/mnt/nuc/c/sevir/data/vil/2019/SEVIR_VIL_RANDO...,2211,2019-11-30 18:46:15,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,,,,36.064382,-88.163621,39.021576,-83.257717,384,384,384000.0,384000.0,0.0,254.0,0.000000
12754,vil,/mnt/nuc/c/sevir/data/vil/2019/SEVIR_VIL_RANDO...,2207,2019-11-30 18:46:15,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,,,,40.973362,-73.972739,43.191927,-68.026592,384,384,384000.0,384000.0,0.0,31.0,0.000000


In [157]:
import multiprocessing.pool


def read(idx: tuple[int, str], r: dict[str, typing.Any]):
    with h5py.File(r["file_name"], "r") as f:
        return idx, np.array(f[idx[1]])
        # ds[i, t] = ("N", "L", "W", "T"), np.array(f[t])


def h5toxarray(df: pd.DataFrame) -> xr.Dataset:
    records = typing.cast(
        dict[tuple[int, str], dict[str, typing.Any]], df.to_dict("index")
    )
    ds = xr.Dataset(attrs=records)
    with multiprocessing.pool.ThreadPool(8) as pool:
        results = pool.starmap(read, records.items())

    for idx, data in enumerate(results):
        ds[idx] = ("N", "L", "W", "T"), data

    return ds


ds = h5toxarray(loader.inputs._values)
ds

KeyboardInterrupt: 

: 

In [None]:
import xarray as xr

h5ds = typing.cast(h5py.Dataset, loader.h5fs[0]._h5ds[0, "vis"])

# pd.Series([loader.h5fs[0]._h5ds[0, "vis"]])
a = np.array(h5ds)

In [None]:
# xr.DataArray(a, dims=["N", "L", "W", "T"])
ds = xr.Dataset()
ds["hello"] = ("N", "L", "W", "T"), a  # xr.DataArray(a, dims=["N", "L", "W", "T"])
ds

In [None]:
# , coords=("event", "n", "x", "y", "t")
# ds = xr.Dataset(
#     # data_vars={
#     #     "vis": (["e", "n", "x", "y", "t"], [[[[[]]]]]),
#     # },
#     # coords={"e": []},
# )
# # coord = "e", "n", "x", "y", "t"
# ds["vis"] = (coord), [[[[[]]]]]
# xr.Dataset(
#     {
#         (i, "vis"): (["n", "x", "y", "t"], np.array(loader.h5fs[0]._h5ds[(i, "vis")]))
#         for i in range(2)
#     }
# )
# h5 = loader.h5fs[0]._h5ds[(i, "vis")]
# ds["vis"].loc[{"e": i}] = ("n", "x", "y", "t"), np.array(h5)  # [:, np.newaxis]

#     # a = np.array(h5)
#     # ds["vis"] = (
#     #     (i, "N"),
#     #     (i, "L"),
#     #     (i, "W"),
#     #     (i, "T"),
#     # ), a
#     # xr.DataArray(a, dims=["N", "L", "W", "T"])
# ds
# print(a1.shape, a2.shape)
# ds

import multiprocessing.pool

ds = xr.Dataset()


with multiprocessing.pool.ThreadPool(4) as pool:

    def f(i):
        return (i, "vis"), np.array(loader.h5fs[0]._h5ds[(i, "vis")])

    for k, v in pool.map(f, range(2)):
        ds[k] = ("n", "x", "y", "t"), v

ds

In [None]:
x, y = loader[1:10]
x

In [None]:
pd.Series({(index, stype): h5[stype][0:1, :, :, :] for (index, stype), h5 in x.items()})

In [None]:
print(" ".join(x for x in dir(x[1, "vis"]["vis"]) if not x.startswith("_")))

from h5py import Dataset
type(x[1, "vis"]["vis"])

In [None]:
s[0]

In [None]:
x, y = loader[0]
x

In [None]:
ids = loader.cat.index.get_level_values(ID).to_frame(index=False)  # .to_numpy()
# _, a = np.unique(ids, return_index=True)
# len(a), len(ids)
# event_id = np.unique(ids)
# arr = np.arange(len(event_id))
# # arr = np.arange(len(ids))
# # np.where(ids == arr)
# # arr
# # create an array of the same length as the index to map the unique IDs to

# # arr[arr[:, np.newaxis] == np.arange(len(ids))]in

# mask = arr == ids[np.newaxis, :]
# mask
# ids.ngroups()ids
ids.assign(index=ids.groupby(ID).ngroup()).set_index("index")

In [None]:
for feature, inputs in loader:
    print(feature, inputs)
    break

In [30]:
import zarr
import numpy as np
# arr = zarr.array(np.arange(1000), chunks=(100,))
store = zarr.DirectoryStore('data/array.zarr')
root = zarr.group(store=store, overwrite=True)
if not 'foo' in root:
    foo = root.create_group('vis')
else:
    foo = root["foo"]
    
# if not root.foo:
# foo.array('bar', arr, chunks=(100,))
foo["bar"]

# list(foo.arrays())

KeyError: 'bar'