# パーセプトロン基準を損失とする単層パーセプトロンの JAX 実装
* Google Colab 上で動作を確認．
* Iris は先頭の 50 個のデータが残りの 100 個から線形分離可能なデータセットです．
* 問題のサイズが小さいので，JAX を使わずに NumPy で実装したほうが速く動作しました．JAX 版は JIT コンパイル時間を除いても NumPy 版の２倍程度の時間がかかっています．GPU で動かすとさらに遅くなります（たぶんデータ転送のオーバーヘッド）．
* NumPy で実装した場合と訓練後の重みが若干違います．float32 のせいかなと思いましたが，そこまで影響はないようです．謎です．

In [None]:
!pip install flax optax

In [None]:
import jax
from functools import partial
import jax.numpy as jnp
from flax import linen as nn


class PerceptronClassifier(nn.Module):
    """Single linear layer whose kernel is initialised to zeros.

    Calling the module applies one ``nn.Dense`` layer to the input and
    returns its raw output; no activation or thresholding is applied here.
    """

    # Number of output units of the dense layer.
    features: int = 1
    # Whether the dense layer carries a bias term.
    use_bias: bool = False

    @nn.compact
    def __call__(self, X):
        linear = nn.Dense(
            features=self.features,
            use_bias=self.use_bias,
            kernel_init=jax.nn.initializers.zeros,
        )
        return linear(X)


@jax.jit
def step(idx, state):
    """Apply one online perceptron update using training sample ``idx``.

    The loss is the perceptron criterion ``max(0, -y * f(x))`` evaluated on a
    single sample; its gradient is handed to ``state.apply_gradients``.

    Args:
        idx: Index of the sample, used to index the module-level ``X_train``
            and ``y_train``.
        state: Train state providing ``apply_fn``, ``params`` and
            ``apply_gradients``.

    Returns:
        The state returned by ``state.apply_gradients``.
    """
    # NOTE(review): X_train / y_train are read from module globals. Under
    # jax.jit such values are captured when the function is traced, so this
    # definition should presumably be re-evaluated whenever the training
    # split is rebuilt -- confirm against the notebook workflow.

    def loss_fn(params, X, y):
        y_pred = state.apply_fn({"params": params}, X)
        preloss = -y_pred * y
        # Same value as (preloss + |preloss|) / 2, i.e. max(0, preloss), but
        # written with `where` so that the derivative at preloss == 0 is 1.
        # jnp.abs differentiates to sign(x), which is 0 at x == 0, so the
        # abs-based form yields only half a gradient for a sample lying
        # exactly on the decision boundary (as every sample does while the
        # weights are still zero) and a full gradient for a misclassified
        # one, making the update size inconsistent between the two cases.
        loss = jnp.where(preloss < 0, 0.0, preloss)
        return loss[0]

    X, y = X_train[idx], y_train[idx]
    grads = jax.grad(loss_fn)(state.params, X, y)
    return state.apply_gradients(grads=grads)

In [None]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import optax
from flax.training.train_state import TrainState

# --- data ---------------------------------------------------------------
# Relabel Iris as a two-class problem: rows 0-49 become +1, the rest -1.
# `labels` is the dataset's own target array, so it is modified in place.
iris_dataset = load_iris()
labels = iris_dataset["target"]
labels[:50] = 1
labels[50:] = -1

# Hold out 25% of the samples for testing, then pass each of the four
# resulting arrays through jax.device_put.
X_train, X_test, y_train, y_test = map(
    jax.device_put,
    train_test_split(iris_dataset["data"], labels, test_size=0.25),
)


# --- model and optimizer --------------------------------------------------
perceptron = PerceptronClassifier()
init_variables = perceptron.init(jax.random.PRNGKey(0), jnp.empty(4))
params = init_variables["params"]
lr = 1.0
# The optimizer only rescales each gradient by -2 * lr.
tx = optax.scale(-2.0 * lr)
state = TrainState.create(apply_fn=perceptron.apply, params=params, tx=tx)

# training
# Each epoch visits the training samples once, in index order, threading
# the train state through `step`.
num_epochs = 3
n_train = X_train.shape[0]
for epoch_idx in range(num_epochs):
    state = jax.lax.fori_loop(0, n_train, step, state)

# test
# Threshold the raw model output at zero to obtain +1 / -1 predictions.
scores = state.apply_fn({"params": state.params}, X_test)
y_pred = jnp.where(scores >= 0, 1, -1).ravel()
# Fraction of test samples whose predicted label equals the true label.
acc = jnp.mean(y_test == y_pred)
print("Accuracy =", acc)