In [1]:
import enum
from typing import Optional, Union

import altair
import attr
import cattr
# `cattr` is rebound to a converter in the next cell; keep the module reachable
# under an alias so that cell can be re-run.
import cattr as cattr_module
import matplotlib.pyplot as plt
import yaml
from IPython.display import display
from sklearn.datasets import load_iris, load_linnerud
from sklearn.neighbors import LocalOutlierFactor
from typing_extensions import Literal

In [2]:
# Rebind `cattr` to a configured converter: later cells register hooks on it and
# call `cattr.structure(...)`. The converter is built from the `cattr_module`
# alias so this cell can be re-run; `cattr.GenConverter` would fail on a second
# run because `cattr` would then be the converter, not the module.
# Re-running this cell creates a fresh converter, so re-run the hook-registration
# cells after it.
cattr = cattr_module.GenConverter(forbid_extra_keys=True)  # type: ignore

In [3]:
def structure_literal(data, cls):
    """Structure hook for ``Literal[...]`` annotations.

    Returns ``data`` unchanged when it equals one of the literal's allowed
    values *and* has exactly the same type as that value.

    Args:
        data: the raw value being structured.
        cls: a ``Literal[...]`` alias; its allowed values are read from ``__args__``.

    Raises:
        ValueError: if ``data`` is not one of the allowed values.
    """
    # A plain ``data in cls.__args__`` compares with ``==`` only, so ``True``
    # would match ``Literal[1]`` and ``1.0`` would match ``Literal[1]``.
    # Requiring the same type keeps those literals distinct. The type is checked
    # first so ``==`` is only evaluated between values of the same type.
    if any(type(data) is type(value) and data == value for value in cls.__args__):
        return data
    raise ValueError(f"{data} does not fit Literal {cls}!")


def _is_literal_type(cls):
    """Return True for ``Literal[...]`` aliases, identified by their ``__origin__``."""
    origin = getattr(cls, "__origin__", None)
    return origin == Literal


# Route every ``Literal[...]`` annotation through ``structure_literal``.
# Example: ``cattr.structure("a", Literal["a"])`` returns ``"a"``.
cattr.register_structure_hook_func(_is_literal_type, structure_literal)

In [12]:
def structure_unions_by_try(data, cls):
    """Structure ``data`` into a ``Union``/``Optional`` by trying each member type.

    Members are tried in declaration order, and the first one the converter
    structures without raising wins.

    ``None`` is handled explicitly:
    - It is returned as-is when ``NoneType`` is a member.
    - ``NoneType`` is skipped for any other value.

    Without this, ``Optional[str]`` (``Union[str, None]``) would try ``str``
    first. cattrs structures ``str`` by calling ``str(value)``, which turns
    ``None`` into the string ``"None"``.

    Raises:
        TypeError: if no member type accepts ``data``. The exception from the
            last attempted member is chained as ``__cause__``.
    """
    none_type = type(None)
    member_types = cls.__args__
    if data is None and none_type in member_types:
        return None

    last_error = None
    for possible_type in member_types:
        if possible_type is none_type:
            # Only reachable when data is not None, so NoneType cannot match.
            continue
        try:
            return cattr.structure(data, possible_type)
        except Exception as error:  # any failure means "try the next member"
            last_error = error
    raise TypeError(f"Could not parse {data} as any of {cls}") from last_error


def _is_union_type(cls):
    """Return True for ``Union[...]`` aliases (``Optional[X]`` is ``Union[X, None]``)."""
    origin = getattr(cls, "__origin__", None)
    return origin == Union


# Route every ``Union[...]``/``Optional[...]`` through ``structure_unions_by_try``.
# Example: ``cattr.structure("b", Union[Literal["a"], Literal["b"], int])`` returns ``"b"``.
cattr.register_structure_hook_func(_is_union_type, structure_unions_by_try)

'b'

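# Nesting & Resolving Ambiguity

`ConfigSchema.plot` accepts either plot schema. Each schema pins its `kind` field to a different `Literal`, so a `plot` mapping can only structure into the schema whose `kind` it names. The union hook tries the schemas in order and keeps the first one that succeeds.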

In [5]:
@attr.s(auto_attribs=True)
class ScatterPlotSchema:
    """``plot`` section describing a scatter plot of column ``x`` against column ``y``.

    The plotting cell draws a 2D Altair chart when ``z`` is None. Otherwise it
    draws a 3D matplotlib scatter using ``z`` as the third axis.
    """

    # Discriminator: only the literal "scatter" structures into this schema.
    kind: Literal["scatter"]
    # Dataset column names plotted on each axis.
    x: str
    y: str
    # Optional third column; None (the default) selects the 2D plot.
    z: Optional[str] = None


@attr.s(auto_attribs=True)
class HeatmapSchema:
    """``plot`` section describing a binned heatmap of column ``x`` against column ``y``."""

    # Discriminator: only the literal "heatmap" structures into this schema.
    kind: Literal["heatmap"]
    # Dataset column names; the plotting cell bins both axes.
    x: str
    y: str


# A ``plot`` section may match either schema. The Union structure hook tries the
# members in this order, and each schema's ``kind`` literal rejects mappings
# meant for the other.
PlotSchema = Union[ScatterPlotSchema, HeatmapSchema]


@enum.unique
class Dataset(enum.Enum):
    """sklearn toy datasets the config can select.

    The plotting cell loads LINNERUD with ``load_linnerud`` and IRIS with
    ``load_iris``. ``@enum.unique`` rejects duplicate member values.
    """

    # NOTE(review): the YAML ``dataset`` key presumably holds one of these string
    # values (cattrs structures enums by value) — confirm against config files.
    LINNERUD = "linnerud"
    IRIS = "iris"


@attr.s(auto_attribs=True)
class ConfigSchema:
    """Top-level config, built from the YAML file with ``cattr.structure``."""

    # Which toy dataset to load.
    dataset: Dataset
    # Passed to LocalOutlierFactor (first positional argument, n_neighbors).
    outlier_n: int
    # Scatter or heatmap plot section, disambiguated by its ``kind`` field.
    plot: PlotSchema

In [10]:
CONFIG_PATH = "config_a.yaml"

with open(CONFIG_PATH) as f:
    raw_config = yaml.safe_load(f)

# safe_load returns None for an empty file and a list/scalar for other top-level
# YAML. Fail with a clear message instead of an AttributeError on .pop below.
if not isinstance(raw_config, dict):
    raise ValueError(
        f"{CONFIG_PATH} must contain a YAML mapping, got {type(raw_config).__name__}"
    )

# Configs without an explicit "version" key are treated as version 1.
config_version = raw_config.pop("version", 1)
if config_version == 1:
    # Migrate v1 configs: the flat plot_x / plot_y keys move into the nested
    # scatter `plot` section that ConfigSchema expects.
    raw_config["plot"] = {
        "kind": "scatter",
        "x": raw_config.pop("plot_x"),
        "y": raw_config.pop("plot_y"),
    }

config = cattr.structure(raw_config, ConfigSchema)

In [11]:
if config.dataset == Dataset.LINNERUD:
    data = load_linnerud(as_frame=True).data
elif config.dataset == Dataset.IRIS:
    data = load_iris(as_frame=True).data
else:
    raise ValueError(f"Unsupported dataset {config.dataset}")

# fit_predict labels each row -1 (outlier) or 1 (inlier). Store a boolean flag
# that the charts colour by. The right-hand side is evaluated before the
# assignment, so the detector only sees the original feature columns.
data["Outlier"] = (
    LocalOutlierFactor(n_neighbors=config.outlier_n)
    .fit_predict(data) == -1
)

# Column names come from the YAML config. Check them up front so a typo raises a
# readable error instead of a bare KeyError or an empty Altair chart.
plot_columns = [config.plot.x, config.plot.y]
if isinstance(config.plot, ScatterPlotSchema) and config.plot.z is not None:
    plot_columns.append(config.plot.z)
missing_columns = [column for column in plot_columns if column not in data.columns]
if missing_columns:
    raise ValueError(
        f"Plot columns {missing_columns} not found in dataset "
        f"{config.dataset.value!r}; available columns: {list(data.columns)}"
    )

if isinstance(config.plot, ScatterPlotSchema) and config.plot.z is None:
    chart = (
        altair.Chart(data)
        .mark_point()
        .encode(x=config.plot.x, y=config.plot.y, color="Outlier")
    )
elif isinstance(config.plot, ScatterPlotSchema):
    # A third column was configured: draw a 3D scatter with matplotlib.
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.scatter3D(
        data[config.plot.x],
        data[config.plot.y],
        data[config.plot.z],
        c=data["Outlier"],
    )
    ax.set_xlabel(config.plot.x)
    ax.set_ylabel(config.plot.y)
    ax.set_zlabel(config.plot.z)
    # Close the figure so the inline backend does not render it on its own.
    # It is displayed once via the `chart` expression at the end of the cell,
    # instead of leaving an Axes text repr as the cell output.
    plt.close(fig)
    chart = fig
elif isinstance(config.plot, HeatmapSchema):
    chart = (
        altair.Chart(data)
        .mark_rect()
        .encode(
            x=altair.X(config.plot.x, bin=True),
            y=altair.Y(config.plot.y, bin=True),
            color="Outlier",
            opacity="count()",
        )
    )
else:
    # Guards against a new PlotSchema member being silently drawn as a heatmap.
    raise ValueError(f"Unsupported plot schema {type(config.plot).__name__}")

chart