In [1]:
%load_ext autoreload
%autoreload 2
%load_ext jupyter_black

In [2]:
from __future__ import annotations

try:
    __file__
except NameError:
    __file__ = __vsc_ipynb_file__  # type: ignore
import enum
import os
import typing
from typing import Any, Self
import typing
import pandas as pd
import numpy as np
import numpy.typing as npt


_ROOT_DIR = os.path.abspath(os.path.dirname(os.path.abspath(__file__)))
PATH_TO_SEVIR = os.getenv("PATH_TO_SEVIR", _ROOT_DIR)
DEFAULT_CATALOG = os.path.join(PATH_TO_SEVIR, "CATALOG.csv")
DEFAULT_DATA_HOME = os.path.join(PATH_TO_SEVIR, "data")
DEFAULT_N_FRAMES = 49  # TODO:  don't hardcode this
# Nominal Frame time offsets in minutes (used for non-raster types)

DEFAULT_FRAME_TIMES = np.arange(-120.0, 125.0, 5) * 60  # in seconds
"""The lightning flashes in each from will represent the 5 minutes leading up the
the frame's time EXCEPT for the first frame, which will use the same flashes as the second frame
(This will be corrected in a future version of SEVIR so that all frames are consistent)"""


class Enum(enum.Enum):
    @classmethod
    def _missing_(cls, value: object) -> Any:
        return cls.__members__[str(value).upper()]

    @classmethod
    def map(cls, __values: typing.Sequence[typing.Any], /) -> list[Self]:
        """class method to map values to enum members"""
        return [
            cls(value)
            for value in ([__values] if isinstance(__values, (str, Enum)) else __values)
        ]


class SEVIRTypes(str, Enum):
    VISIBLE = "vis"
    IR_069 = "ir069"
    IR_107 = "ir107"
    VERTICALLY_INTEGRATED_LIQUID = "vil"
    LIGHTNING = "lght"

    def get_dtype(self) -> npt.DTypeLike:
        return SEVIR_DTYPES[self]

    def get_cmap(self) -> typing.Any:
        raise NotImplementedError


SEVIR_DTYPES: dict[SEVIRTypes, npt.DTypeLike] = {
    SEVIRTypes.VERTICALLY_INTEGRATED_LIQUID: np.uint8,
    SEVIRTypes.VISIBLE: np.int16,
    SEVIRTypes.IR_069: np.int16,
    SEVIRTypes.IR_107: np.int16,
    SEVIRTypes.LIGHTNING: np.int16,
}

VISIBLE, IR_069, IR_107, VERTICALLY_INTEGRATED_LIQUID, LIGHTNING = (
    SEVIRTypes.VISIBLE,
    SEVIRTypes.IR_069,
    SEVIRTypes.IR_107,
    SEVIRTypes.VERTICALLY_INTEGRATED_LIQUID,
    SEVIRTypes.LIGHTNING,
)

IMAGE_TYPE = "img_type"
CATALOG_COLUMNS = (
    ID,
    FILE_NAME,
    FILE_INDEX,
    IMG_TYPE,
    TIME_UTC,
    MINUTE_OFFSETS,
    EPISODE_ID,
    EVENT_ID,
    EVENT_TYPE,
    LL_LAT,
    LL_LON,
    UR_LAT,
    UR_LON,
    PROJ,
    SIZE_X,
    SIZE_Y,
    HEIGHT_M,
    WIDTH_M,
    DATA_MIN,
    DATA_MAX,
    PCT_MISSING,
) = (
    "id",
    "file_name",
    "file_index",
    "img_type",
    "time_utc",
    "minute_offsets",
    "episode_id",
    "event_id",
    "event_type",
    "llcrnrlat",
    "llcrnrlon",
    "urcrnrlat",
    "urcrnrlon",
    "proj",
    "size_x",
    "size_y",
    "height_m",
    "width_m",
    "data_min",
    "data_max",
    "pct_missing",
)

CATALOG_DTYPES = {
    ID: "string",
    EVENT_ID: "Int64",
    IMAGE_TYPE: "string",
    PROJ: "string",
    FILE_NAME: "string",
    EVENT_TYPE: "string",
}

In [3]:
def html(obj) -> str:
    if not hasattr(obj, "_repr_html_"):
        return repr(obj)
    return obj._repr_html_()


class SEVIRBase:
    __slots__ = ("_values", "_index")
    _values: pd.DataFrame
    _index: pd.MultiIndex

    def __init__(self, values: pd.DataFrame) -> None:
        self._values = values
        self._index = values.index

    @property
    def values(self) -> pd.DataFrame:
        return self._values.copy()

    @property
    def index(self) -> pd.MultiIndex:
        return self._index.copy()

    @property
    def loc(self):
        return SEVIRBase(self._values)

    def __repr__(self) -> str:
        return repr(self._values)

    def _repr_html_(self) -> str:
        return html(self._values)

    def __getitem__(self, idx: int | list[int] | slice) -> pd.DataFrame:
        return self._values.loc[idx]


class SEVIRCatalog(SEVIRBase):
    def __init__(
        self,
        catalog: str | pd.DataFrame,
        *,
        image_types: set[SEVIRTypes],
        shuffle: int | None = None,
    ) -> None:
        if isinstance(catalog, pd.DataFrame):
            df = catalog
        else:
            df = self._read_catalog(catalog, image_types)
        if shuffle:
            df = df.sample(frac=1)
        super().__init__(df)

    def _read_catalog(self, catalog: str, image_types: set[SEVIRTypes]) -> pd.DataFrame:
        df = (
            pd.read_csv(
                catalog,
                parse_dates=[TIME_UTC],
                low_memory=False,
                dtype=CATALOG_DTYPES,
            )
            .drop(columns=[PROJ])
            .drop_duplicates()
        )
        # remove all rows that don't have the selected image types
        df = df.loc[df[IMG_TYPE].isin(image_types)]
        # the ID columns is a string with either a "S" or "R" prefix.
        df[ID] = df[ID].str.slice(1).astype(int)
        # set the index to the ID column
        df.set_index([ID], inplace=True)
        # group all of the files by their ID, and remove any where there are not complete set of image types
        mask = df.groupby(ID)[IMG_TYPE].size() == len(set(image_types))
        # lastly mask out the index to only include the IDs that have all of the image types
        return df.loc[mask[mask].index, :].set_index(IMG_TYPE, append=True)

    def prefix(self, o: str) -> SEVIRCatalog:
        self._values[FILE_NAME] = [os.path.join(o, p) for p in self._values[FILE_NAME]]
        return self

    def get_paths(self) -> pd.Series[str]:
        return self._values[FILE_NAME]

    def validate_paths(self) -> SEVIRCatalog:
        for file in self.get_paths():
            if not os.path.exists(file):
                raise FileNotFoundError(file)

    def split(
        self,
        x: list[SEVIRTypes],
        y: list[SEVIRTypes],
    ) -> tuple[SEVIRCatalog, SEVIRCatalog]:
        x_df = self.loc[pd.IndexSlice[:, x], :]
        y_df = self.loc[pd.IndexSlice[:, y], :]
        return SEVIRCatalog(x_df, image_types=set(x)), SEVIRCatalog(
            y_df, image_types=set(y)
        )


class SEVIRLoader:
    inputs: SEVIRCatalog
    features: SEVIRCatalog

    def __init__(
        self,
        inputs: list[SEVIRTypes],
        features: list[SEVIRTypes],
        *,
        path_to_sevir: str | None = None,
        catalog: str = "CATALOG.csv",
        shuffle: int = 0,
        validate_paths: bool = False,
        batch_size: int = 32,
    ) -> None:
        cat = SEVIRCatalog(catalog, image_types=set(inputs + features), shuffle=shuffle)

        if path_to_sevir is not None:
            cat.prefix(path_to_sevir)
        if validate_paths:
            cat = cat.validate_paths()

        # TODO: reindex the DataFrame so that the ID's are [0, ..., len(cat.index.get_level_values(ID).unique()) - 1]
        self.index = cat.index.get_level_values(ID).unique()
        self.inputs, self.features = cat.split(inputs, features)
        self.cat = cat
        self.batch_size = batch_size

    def get_batch(self, idx: int) -> tuple[pd.Series, pd.Series]:
        inputs = self.inputs.loc[idx, :]
        features = self.features.loc[idx, :]
        return inputs[FILE_NAME], features[FILE_NAME]

    def __getitem__(self, idx: int) -> tuple[pd.Series, pd.Series]:
        return self.get_batch(self.index[idx])

    def __len__(self) -> int:
        return len(self.index)

    def _repr_html_(self):
        return f"""\
<h3>inputs</h3>
{html(self.inputs)}
<h3>features</h3>
{html(self.features)}"""

    def __repr__(self):
        return f"""\
[inputs]
{repr(self.inputs)}
[features]
{repr(self.features)}"""

    def __iter__(self) -> tuple[pd.Series, pd.Series]:
        yield from (self.get_batch(idx)[FILE_NAME] for idx in self.index)


loader = SEVIRLoader(
    [SEVIRTypes.VISIBLE, SEVIRTypes.IR_069, SEVIRTypes.IR_107],
    [SEVIRTypes.VERTICALLY_INTEGRATED_LIQUID, SEVIRTypes.LIGHTNING],
    path_to_sevir="/home/jupyter/data/sevir",
    catalog="CATALOG.csv",
    shuffle=0,
)
loader

Unnamed: 0_level_0,Unnamed: 1_level_0,file_name,file_index,time_utc,minute_offsets,episode_id,event_id,event_type,llcrnrlat,llcrnrlon,urcrnrlat,urcrnrlon,size_x,size_y,height_m,width_m,data_min,data_max,pct_missing
id,img_type,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
730231,vis,/home/jupyter/data/sevir/vis/2018/SEVIR_VIS_ST...,4,2018-02-06 20:06:00,-121:-116:-111:-106:-101:-96:-91:-86:-81:-76:-...,121968.0,730231,Hail,30.814419,-97.017593,34.158585,-92.804324,768,768,384000.0,384000.0,0.070697,1.169730,0.0
730443,vis,/home/jupyter/data/sevir/vis/2018/SEVIR_VIS_ST...,20,2018-02-10 22:55:00,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,122007.0,730443,Flood,36.146405,-84.414905,38.931600,-79.347818,768,768,384000.0,384000.0,-0.003386,0.555044,0.0
730587,vis,/home/jupyter/data/sevir/vis/2018/SEVIR_VIS_ST...,11,2018-02-11 19:13:00,-118:-113:-108:-103:-98:-93:-88:-83:-78:-73:-6...,122033.0,730587,Flood,37.855918,-78.668305,40.349141,-73.223167,768,768,384000.0,384000.0,0.027006,0.740143,0.0
731013,vis,/home/jupyter/data/sevir/vis/2018/SEVIR_VIS_ST...,14,2018-02-10 22:50:00,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,122115.0,731013,Tornado,31.197225,-91.317664,34.312541,-86.864100,768,768,384000.0,384000.0,-0.003082,0.914687,0.0
731511,vis,/home/jupyter/data/sevir/vis/2018/SEVIR_VIS_ST...,2,2018-02-12 19:30:00,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,122095.0,731511,Tornado,34.582982,-121.120933,38.698921,-117.849728,768,768,384000.0,384000.0,0.025700,0.949320,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
19113018467785,ir107,/home/jupyter/data/sevir/ir107/2019/SEVIR_IR10...,460,2019-11-30 18:44:00,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,,,,34.844928,-91.123460,37.936701,-86.429825,192,192,384000.0,384000.0,-66.673363,14.001942,0.0
19113018467861,ir107,/home/jupyter/data/sevir/ir107/2019/SEVIR_IR10...,459,2019-11-30 18:44:00,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,,,,36.192798,-85.468485,39.026754,-80.440684,192,192,384000.0,384000.0,-63.983356,13.186085,0.0
19113018467863,ir107,/home/jupyter/data/sevir/ir107/2019/SEVIR_IR10...,458,2019-11-30 18:44:00,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,,,,36.064382,-88.163621,39.021576,-83.257717,192,192,384000.0,384000.0,-67.114944,10.598397,0.0
19113018468164,ir107,/home/jupyter/data/sevir/ir107/2019/SEVIR_IR10...,454,2019-11-30 18:44:00,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,,,,40.973362,-73.972739,43.191927,-68.026592,192,192,384000.0,384000.0,-21.303188,11.281363,0.0

Unnamed: 0_level_0,Unnamed: 1_level_0,file_name,file_index,time_utc,minute_offsets,episode_id,event_id,event_type,llcrnrlat,llcrnrlon,urcrnrlat,urcrnrlon,size_x,size_y,height_m,width_m,data_min,data_max,pct_missing
id,img_type,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1
730231,vil,/home/jupyter/data/sevir/vil/2018/SEVIR_VIL_ST...,600,2018-02-06 20:06:00,-121:-116:-111:-106:-101:-96:-91:-86:-81:-76:-...,121968.0,730231,Hail,30.814419,-97.017593,34.158585,-92.804324,384,384,384000.0,384000.0,0.0,254.0,0.061224
730443,vil,/home/jupyter/data/sevir/vil/2018/SEVIR_VIL_ST...,205,2018-02-10 22:55:00,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,122007.0,730443,Flood,36.146405,-84.414905,38.931600,-79.347818,384,384,384000.0,384000.0,0.0,188.0,0.000000
730587,vil,/home/jupyter/data/sevir/vil/2018/SEVIR_VIL_ST...,665,2018-02-11 19:13:00,-118:-113:-108:-103:-98:-93:-88:-83:-78:-73:-6...,122033.0,730587,Flood,37.855918,-78.668305,40.349141,-73.223167,384,384,384000.0,384000.0,0.0,241.0,0.000000
731013,vil,/home/jupyter/data/sevir/vil/2018/SEVIR_VIL_ST...,82,2018-02-10 22:50:00,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,122115.0,731013,Tornado,31.197225,-91.317664,34.312541,-86.864100,384,384,384000.0,384000.0,0.0,221.0,0.000000
731511,vil,/home/jupyter/data/sevir/vil/2018/SEVIR_VIL_ST...,351,2018-02-12 19:30:00,-120:-115:-110:-105:-100:-95:-90:-85:-80:-75:-...,122095.0,731511,Tornado,34.582982,-121.120933,38.698921,-117.849728,384,384,384000.0,384000.0,0.0,174.0,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
19113018467785,lght,/home/jupyter/data/sevir/lght/2019/SEVIR_LGHT_...,0,2019-11-30 18:44:00,,,,,34.844928,-91.123460,37.936701,-86.429825,48,48,384000.0,384000.0,0.0,383499.0,0.000000
19113018467861,lght,/home/jupyter/data/sevir/lght/2019/SEVIR_LGHT_...,0,2019-11-30 18:44:00,,,,,36.192798,-85.468485,39.026754,-80.440684,48,48,384000.0,384000.0,0.0,383499.0,0.000000
19113018467863,lght,/home/jupyter/data/sevir/lght/2019/SEVIR_LGHT_...,0,2019-11-30 18:44:00,,,,,36.064382,-88.163621,39.021576,-83.257717,48,48,384000.0,384000.0,0.0,383499.0,0.000000
19113018468164,lght,/home/jupyter/data/sevir/lght/2019/SEVIR_LGHT_...,0,2019-11-30 18:44:00,,,,,40.973362,-73.972739,43.191927,-68.026592,48,48,384000.0,384000.0,0.0,383499.0,0.000000


In [4]:
for feature, inputs in loader:
    print(feature, inputs)
    break

TypeError: tuple indices must be integers or slices, not str