In [11]:
%load_ext jupyter_black

In [96]:
from __future__ import annotations

import enum
import pandas as pd
import abc
from typing import Generic, TypeVar, Any, Callable, Iterable, Iterator, Hashable, Union, overload, TypeAlias, Mapping


EnumT = TypeVar("EnumT", bound=enum.Enum)
Aliases = Mapping[str, list[str]]
_EnumNames: TypeAlias = str | Iterable[str] | Iterable[Iterable[str | Any]] | Mapping[str, Any]


class GenericEnumMeta(enum.EnumMeta, Generic[EnumT]):
    def __class_getitem__(cls, __x: Any) -> TypeAlias:
        return cls




class PandasEnumMeta(GenericEnumMeta[EnumT]):
    __iter__: Callable[..., Iterator[EnumT]]  # type: ignore[assignment]
    _member_map_: pd.DataFrame

    def __new__(
        cls,
        name: str,
        bases: tuple[type, ...],
        cls_dict: enum._EnumDict,
        aliases: Aliases | None = None,
    ):
        obj = super().__new__(cls, name, bases, cls_dict)
        if aliases is None:
            aliases = obj._get_aliases()
        obj._member_map_ = pd.DataFrame.from_dict(
            {k: [v, *aliases.get(k, [])] for k, v in obj._member_map_.items()}, orient="index"
        ).astype("string")
        return obj

    @overload  # type: ignore
    def __getitem__(self, names: str) -> EnumT:
        ...

    @overload
    def __getitem__(self, names: list[str]) -> list[EnumT]:
        ...

    def __getitem__(self, names: str | list[str]) -> EnumT | list[EnumT]:
        x = self._member_map_.loc[names, 0]
        if isinstance(x, pd.Series):
            return x.to_list()
        return x

    def __call__(
        cls,
        value: Any,
        names: _EnumNames | None = None,
        *,
        module: str | None = None,
        qualname: str | None = None,
        type: type | None = None,
        start: int = 1,
    ):
        if (
            isinstance(value, cls.__mro__[:-1]) and names is None
        ):  # could be simple value lookup or functional API: we're creating a new Enum type
            return super().__call__(value, names, module=module, qualname=qualname, type=type, start=start)
        return cls.intersection(value)

    @property
    def values(self) -> pd.Series[EnumT]:
        return self._member_map_.iloc[:, 0]

    def to_list(self) -> list[EnumT]:
        return self.values.to_list()

    def _get_aliases(cls) -> Mapping[str, list[str]]:
        return {}

    def to_series(cls) -> pd.Series[str]:
        return cls._member_map_

    def is_in(cls, __x: Iterable[Any]) -> pd.Series[bool]:
        if isinstance(__x, str):
            __x = [__x]
        return cls.to_series().isin(__x).any(axis=1)

    def difference(cls, __x: Iterable[Any]) -> list[EnumT]:
        return cls[~cls.is_in(__x)]

    def intersection(cls, __x: Iterable[Any]) -> list[EnumT]:
        return cls[cls.is_in(__x)]

    def map(cls, __x: Iterable[Any]) -> Mapping[str, EnumT]:
        return {x: cls.__call__(x) for x in map(str, __x)}


class MyEnum(
    str,
    enum.Enum,
    metaclass=PandasEnumMeta["MyEnum"],  # type: ignore[misc]
):
    """Example str-valued enum whose members also answer to extra alias strings."""

    A = "a"
    B = "b"
    C = "c"

    @classmethod
    def _get_aliases(cls) -> Aliases:
        """Extra strings that should resolve to each member, keyed by member name."""
        table: Aliases = dict(
            A=["alpha", "apple"],
            B=["bravo"],
            C=["charlie"],
        )
        return table


# Basic member lookups.
assert MyEnum["A"] == MyEnum.A
assert MyEnum[["A"]] == [MyEnum.A]
assert MyEnum.to_list() == [MyEnum.A, MyEnum.B, MyEnum.C]

# Only a cell's last expression is displayed, so intermediate results are asserted
# instead of being computed and discarded.
assert MyEnum.values.isin(["a"]).to_list() == [True, False, False]
assert MyEnum.is_in(["a"]).to_list() == [True, False, False]
assert MyEnum(["a"]) == [MyEnum.A]

# A boolean mask from is_in() can index the enum directly.
mask = MyEnum.is_in(["a", "b"])
assert MyEnum[mask] == [MyEnum.A, MyEnum.B]
assert MyEnum[["A", "B"]] == [MyEnum.A, MyEnum.B]

# Values and aliases match; member names such as "C" do not (C's value is "c").
assert MyEnum(["alpha", "a", "C"]) == [MyEnum.A]

MyEnum._member_map_

Unnamed: 0,0,1,2
A,MyEnum.A,alpha,apple
B,MyEnum.B,bravo,
C,MyEnum.C,charlie,
