In [1]:
%load_ext autoreload
%autoreload 2
%load_ext jupyter_black

In [18]:
from __future__ import annotations

try:
    __file__
except NameError:
    __file__ = __vsc_ipynb_file__  # type: ignore
import enum
import os
import typing
from typing import Any, Self

import pandas as pd
import numpy as np
import numpy.typing as npt



# Directory that contains this file; used as the fallback SEVIR location.
_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))

# The PATH_TO_SEVIR environment variable, when set, overrides the fallback.
PATH_TO_SEVIR = os.environ.get("PATH_TO_SEVIR", _ROOT_DIR)
DEFAULT_CATALOG = os.path.join(PATH_TO_SEVIR, "CATALOG.csv")
DEFAULT_DATA_HOME = os.path.join(PATH_TO_SEVIR, "data")

DEFAULT_N_FRAMES = 49  # TODO:  don't hardcode this

# Nominal frame time offsets (used for non-raster types): -120 to +120 minutes
# in 5-minute steps, converted to seconds by the factor of 60.
DEFAULT_FRAME_TIMES = np.arange(-120.0, 125.0, 5) * 60  # in seconds
"""The lightning flashes in each frame will represent the 5 minutes leading up to
the frame's time EXCEPT for the first frame, which will use the same flashes as the second frame.
(This will be corrected in a future version of SEVIR so that all frames are consistent.)"""


class Enum(enum.Enum):
    @classmethod
    def _missing_(cls, value: object) -> Any:
        return cls.__members__[str(value).upper()]

    @classmethod
    def map(cls, __values: typing.Sequence[typing.Any], /) -> list[Self]:
        """class method to map values to enum members"""
        return [
            cls(value)
            for value in ([__values] if isinstance(__values, (str, Enum)) else __values)
        ]


class SEVIRTypes(str, Enum):
    """String-valued identifiers of the SEVIR image types."""

    VISIBLE = "vis"
    IR_069 = "ir069"
    IR_107 = "ir107"
    VERTICALLY_INTEGRATED_LIQUID = "vil"
    LIGHTNING = "lght"

    def get_dtype(self) -> npt.DTypeLike:
        """Return this member's entry from the module-level ``SEVIR_DTYPES`` table."""
        dtype = SEVIR_DTYPES[self]
        return dtype

    def get_cmap(self) -> typing.Any:
        """Placeholder: colour maps are not implemented, so this always raises."""
        raise NotImplementedError()


# dtype associated with each image type: VIL is uint8, every other type is int16.
# Insertion order is VIL, VIS, IR069, IR107, LGHT.
SEVIR_DTYPES: dict[SEVIRTypes, npt.DTypeLike] = {
    SEVIRTypes.VERTICALLY_INTEGRATED_LIQUID: np.uint8,
    **dict.fromkeys(
        (
            SEVIRTypes.VISIBLE,
            SEVIRTypes.IR_069,
            SEVIRTypes.IR_107,
            SEVIRTypes.LIGHTNING,
        ),
        np.int16,
    ),
}

# Module-level shorthands for the enum members.
VISIBLE = SEVIRTypes.VISIBLE
IR_069 = SEVIRTypes.IR_069
IR_107 = SEVIRTypes.IR_107
VERTICALLY_INTEGRATED_LIQUID = SEVIRTypes.VERTICALLY_INTEGRATED_LIQUID
LIGHTNING = SEVIRTypes.LIGHTNING

# Names of the catalog columns, one constant per column.
ID = "id"
FILE_NAME = "file_name"
FILE_INDEX = "file_index"
IMG_TYPE = "img_type"
TIME_UTC = "time_utc"
MINUTE_OFFSETS = "minute_offsets"
EPISODE_ID = "episode_id"
EVENT_ID = "event_id"
EVENT_TYPE = "event_type"
LL_LAT = "llcrnrlat"
LL_LON = "llcrnrlon"
UR_LAT = "urcrnrlat"
UR_LON = "urcrnrlon"
PROJ = "proj"
SIZE_X = "size_x"
SIZE_Y = "size_y"
HEIGHT_M = "height_m"
WIDTH_M = "width_m"
DATA_MIN = "data_min"
DATA_MAX = "data_max"
PCT_MISSING = "pct_missing"

# Second name for the same column as IMG_TYPE.
IMAGE_TYPE = "img_type"

# All column names, in the same order as the constants above.
CATALOG_COLUMNS = (
    ID,
    FILE_NAME,
    FILE_INDEX,
    IMG_TYPE,
    TIME_UTC,
    MINUTE_OFFSETS,
    EPISODE_ID,
    EVENT_ID,
    EVENT_TYPE,
    LL_LAT,
    LL_LON,
    UR_LAT,
    UR_LON,
    PROJ,
    SIZE_X,
    SIZE_Y,
    HEIGHT_M,
    WIDTH_M,
    DATA_MIN,
    DATA_MAX,
    PCT_MISSING,
)

In [214]:
class SEVIRBase:
    """Thin wrapper around a pandas DataFrame stored in ``_values``.

    Subclasses are expected to populate ``_values`` (and ``_index``); this base
    class only provides read access and notebook-friendly representations.
    """

    __slots__ = ("_values", "_index")
    _values: pd.DataFrame
    _index: pd.MultiIndex

    @property
    def values(self) -> pd.DataFrame:
        """Return a copy of the wrapped frame so callers cannot modify it in place."""
        return self._values.copy()

    @property
    def index(self) -> pd.MultiIndex:
        """Return the index of the wrapped frame.

        The index is read directly from ``_values``. Going through ``values``
        would copy the whole frame only to look at its index.
        """
        return self._values.index

    def __repr__(self) -> str:
        """Plain-text representation: that of the wrapped frame."""
        return repr(self._values)

    def _repr_html_(self) -> str:
        """HTML representation: that of the wrapped frame."""
        return self._values._repr_html_()

    def __getitem__(self, idx: int | list[int] | slice) -> pd.DataFrame:
        """Label-based selection; ``idx`` is forwarded unchanged to ``DataFrame.loc``."""
        return self._values.loc[idx]


class SEVIRCatalog(SEVIRBase):
    """Catalog restricted to the IDs that provide every requested image type.

    The wrapped frame is indexed by ``(id, img_type)``, where ``id`` is the
    catalog ID with its first character removed and converted to ``int``.
    """

    # Explicit dtypes handed to ``read_csv`` for these columns.
    dtypes = {
        ID: "string",
        EVENT_ID: "Int64",
        IMAGE_TYPE: "string",
        PROJ: "string",
        FILE_NAME: "string",
        EVENT_TYPE: "string",
    }

    def __init__(
        self,
        catalog: str = "CATALOG.csv",
        image_types: set[SEVIRTypes] = set(SEVIRTypes),
        shuffle: bool = False,
    ) -> None:
        """Read ``catalog`` and keep only the complete IDs.

        Args:
            catalog: Location of the catalog CSV, passed to ``pandas.read_csv``.
            image_types: Image types every retained ID must provide. The default
                set is created once at definition time and is never mutated here.
            shuffle: When true, the rows are returned in random order.
        """
        df = self._read_catalog(catalog, image_types)
        if shuffle:
            df = df.sample(frac=1)
        self._values, self._index = df, typing.cast(pd.MultiIndex, df.index)

    def _read_catalog(self, catalog: str, image_types: set[SEVIRTypes]) -> pd.DataFrame:
        """Load the catalog and filter it down to complete IDs.

        Args:
            catalog: Location of the catalog CSV.
            image_types: Image types that must all be present for an ID.

        Returns:
            A frame indexed by ``(id, img_type)`` that contains only IDs with
            exactly one row for each requested image type.
        """
        wanted = set(image_types)
        events = (
            pd.read_csv(
                catalog,
                parse_dates=[TIME_UTC],
                low_memory=False,
                dtype=self.dtypes,
            )
            .drop(columns=[PROJ])
            .drop_duplicates()
            # remove all rows that don't have the selected image types
            .loc[lambda frame: frame[IMG_TYPE].isin(wanted)]
            # the ID column is a string with either a "S" or "R" prefix; `assign`
            # builds a new frame, so no column is written onto a filtered slice
            .assign(**{ID: lambda frame: frame[ID].str.slice(1).astype(int)})
            .set_index(ID)
        )
        # An ID is complete only when it has one row per requested type. The row
        # count alone is not enough: a duplicated type could hide a missing one,
        # so the number of distinct types is checked as well.
        types_per_id = events.groupby(level=ID)[IMG_TYPE]
        n_wanted = len(wanted)
        complete = (types_per_id.size() == n_wanted) & (types_per_id.nunique() == n_wanted)
        # lastly keep only the complete IDs and add the image type to the index
        return events.loc[complete[complete].index, :].set_index(IMG_TYPE, append=True)


class SEVIRLoader:
    """Splits a ``SEVIRCatalog`` into input rows and feature rows by image type."""

    inputs: pd.DataFrame
    features: pd.DataFrame

    def __init__(
        self,
        inputs: list[SEVIRTypes],
        features: list[SEVIRTypes],
        /,
    ) -> None:
        """Build a catalog covering both lists and slice it once per list.

        Args:
            inputs: Image types whose catalog rows are stored in ``self.inputs``.
            features: Image types whose catalog rows are stored in ``self.features``.
        """
        requested = set(inputs + features)
        catalog = SEVIRCatalog(image_types=requested)

        # Every ID on the first index level, the listed image types on the second.
        input_rows = pd.IndexSlice[:, inputs]
        feature_rows = pd.IndexSlice[:, features]

        self.inputs = catalog[input_rows, :]
        self.features = catalog[feature_rows, :]

    def get_batch(self, idx: int) -> tuple[pd.Series, pd.Series]:
        """Return the file names of the input rows and feature rows selected by ``idx``."""
        return (
            self.inputs.loc[idx, FILE_NAME],
            self.features.loc[idx, FILE_NAME],
        )


# First list: input image types; second list: feature image types.
loader = SEVIRLoader(
    [VISIBLE, IR_069, IR_107],
    [VERTICALLY_INTEGRATED_LIQUID, LIGHTNING],
)
# Trailing expression so the notebook displays the file names found for this ID.
loader.get_batch(730231)

(img_type
 vis       vis/2018/SEVIR_VIS_STORMEVENTS_2018_0201_0228.h5
 ir069    ir069/2018/SEVIR_IR069_STORMEVENTS_2018_0101_0...
 ir107    ir107/2018/SEVIR_IR107_STORMEVENTS_2018_0101_0...
 Name: file_name, dtype: string,
 img_type
 vil     vil/2018/SEVIR_VIL_STORMEVENTS_2018_0101_0630.h5
 lght    lght/2018/SEVIR_LGHT_ALLEVENTS_2018_0201_0301.h5
 Name: file_name, dtype: string)