### 前準備


In [None]:
from typing import Self, Protocol

import ibis
import jax.numpy as jnp
import jaxtyping as jnpt
import pandas as pl
import plotly.express as px
import plotly.graph_objects as go

In [None]:
# Notebook-wide ibis configuration; this cell must run before any cell that builds or
# executes ibis expressions (e.g. the `ibis.memtable(...).execute()` in plot_classifier_loss).
# Select Polars as ibis's default backend — presumably the engine that executes
# backend-less expressions such as memtables; confirm against the ibis docs.
ibis.set_backend("polars")
# Turn on ibis interactive mode (per the ibis docs, expressions are executed and their
# results rendered when displayed in a notebook, rather than shown unevaluated).
ibis.options.interactive = True

### テスト関数


In [None]:
class Classifier[N: int, P: int](Protocol):
    learning_rate: float
    n_epochs: int
    loss_by_epochs: list[float]

    def fit(self, X: jnpt.Float32[jnpt.Array, "N P"], y: jnpt.Int32[jnpt.Array, "N 1"]) -> Self:
        ...

    def predict(self, X: jnpt.Float32[jnpt.Array, "N P"]) -> jnpt.Int32[jnpt.Array, "N"]:
        ...

In [None]:
def show_fig(fig: go.Figure, title: str) -> None:
    """Give *fig* the notebook's standard title and 700x500 size, then display it.

    The layout of ``fig`` is updated in place before it is shown.
    """
    shared_size = {"height": 500, "width": 700}
    fig.update_layout(title=title, **shared_size)
    fig.show()

In [None]:
def plot_classifier_loss[N: int, P: int](trained_classifier: Classifier[N, P]) -> None:
    misclassification_df: pl.DataFrame = ibis.memtable(
        {
            "Epochs": list(range(1, len(trained_classifier.loss_by_epochs) + 1)),
            "Loss": trained_classifier.loss_by_epochs,
        }
    ).execute()

    show_fig(
        fig=px.line(misclassification_df, x="Epochs", y="Loss", markers=True),
        title=f"{type(trained_classifier).__name__} - Learning rate {trained_classifier.learning_rate}, Epochs {trained_classifier.n_epochs}",
    )

In [None]:
def plot_decision_regions[N: int, P: int](  # TODO: Refactoring
    X: jnpt.Float[jnpt.Array, "N P"],
    y: jnpt.Int[jnpt.Array, "1 P"],
    classifier: Classifier[N, P],
    resolution: float = 0.02,
) -> None:
    x1_min, x1_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    x2_min, x2_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx1, xx2 = jnp.meshgrid(jnp.arange(x1_min, x1_max, resolution), jnp.arange(x2_min, x2_max, resolution))
    Z = classifier.predict(jnp.array([xx1.ravel(), xx2.ravel()]).T)
    Z = Z.reshape(xx1.shape)

    contour = go.Contour(x=xx1[0], y=xx2[:, 0], z=Z, showscale=False, colorscale="Viridis")

    scatter = go.Scatter(
        x=X[:, 0],
        y=X[:, 1],
        mode="markers",
        marker={"color": y, "colorscale": "Viridis", "line_width": 1},
        showlegend=False,
    )

    fig = go.Figure(data=[contour, scatter])
    fig.update_layout(
        title="Decision regions",
        xaxis_title="sepal length [cm]",
        yaxis_title="petal length [cm]",
        legend_title="Classes",
        margin={"l": 50, "r": 50, "b": 100, "t": 100, "pad": 4},
        height=500,
        width=700,
    )
    fig.show()