# Softmax Regression

In [None]:
%pip install \
    git+https://github.com/deepmind/dm-haiku@v0.0.4 \
    git+https://github.com/deepmind/optax@v0.0.9

In [2]:
import haiku as hk

from jax import jit, partial, vmap, grad
from jax import random
import jax.lax as lax
import jax.nn as nn
import jax.numpy as np
from jax.scipy.special import logit

import optax

In [3]:
# Number of classes: used as the output width of the linear layer in `model`
# and as the depth of the one-hot encoded labels.
y_size = 4

In [4]:
@hk.without_apply_rng
@hk.transform
def model(x):
    """Map an input `x` to `y_size` unnormalised class scores (logits).

    The function body is a single `hk.Linear` layer; after the decorators
    are applied, `model` exposes `init(rng, x)` and `apply(params, x)`.
    """
    return hk.Linear(output_size=y_size)(x)

In [5]:
# Root PRNG key with a fixed seed; it is split whenever a fresh subkey is
# needed so the notebook's random draws are reproducible.
rng = random.PRNGKey(42)

In [6]:
# Shape of a single (unbatched) input example.
x_shape = (10,)
# Randomly initialise a "generating" set of parameters; these define the
# labels of the synthetic dataset below.
rng, r = random.split(rng)
generating_model_state = model.init(r, np.zeros(x_shape))
# Draw 1000 input examples from a standard normal distribution.
rng, r = random.split(rng)
x = random.normal(r, (1000,) + x_shape)
# Evaluate the generating model on every example (vmap maps over the leading
# axis) and label each example with the index of its largest logit.
y = vmap(partial(model.apply, generating_model_state))(x)
y = np.argmax(y, axis=-1)
# One-hot encoding of the integer labels, consumed by `loss_fn`.
y_one_hot = nn.one_hot(y, y_size)

@jit
def loss_fn(model_state):
    """Return the mean softmax cross-entropy over the whole dataset.

    Args:
        model_state: parameters passed as the first argument of `model.apply`.

    Returns:
        Scalar: the cross-entropy between the model's logits on the global
        inputs `x` and the global one-hot labels `y_one_hot`, averaged over
        all examples.
    """
    predict_one = partial(model.apply, model_state)
    logits = vmap(predict_one)(x)  # map the per-example model over axis 0
    per_example_loss = optax.softmax_cross_entropy(logits, y_one_hot)
    return np.mean(per_example_loss)

@jit
def accuracy(model_state):
    """Return the fraction of dataset examples classified correctly.

    Args:
        model_state: parameters passed as the first argument of `model.apply`.

    Returns:
        Scalar: the mean, over the global inputs `x`, of whether the index
        of the largest logit equals the corresponding global label in `y`.
    """
    predict_one = partial(model.apply, model_state)
    logits = vmap(predict_one)(x)  # map the per-example model over axis 0
    predicted_labels = np.argmax(logits, axis=-1)
    is_correct = y == predicted_labels
    return np.mean(is_correct)

In [7]:
# Number of optimizer updates performed by a single call to `train`.
steps = 50
# Learning rate handed to the optimizer below.
start_learning_rate = 1e-2
# AdamW optimizer (Adam with decoupled weight decay) used by `train`.
optimizer = optax.adamw(start_learning_rate, weight_decay=1e-2)

@jit
def train(model_state, optimizer_state):
    """Run `steps` consecutive optimizer updates and return the new state.

    Args:
        model_state: current model parameters.
        optimizer_state: current state of the module-level `optimizer`.

    Returns:
        A `(model_state, optimizer_state)` tuple after `steps` updates.
    """

    def train_step(_, carry):
        # One update: gradient of the loss -> optimizer update -> new params.
        params, opt_state = carry
        grads = grad(loss_fn)(params)
        # The current params are passed as the third argument to `update`.
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

    # fori_loop threads the (params, optimizer state) pair through every
    # iteration and returns the final pair.
    return lax.fori_loop(0, steps, train_step, (model_state, optimizer_state))

In [8]:
# Fresh random initialisation of the parameters that will be fitted; the
# dummy input has the shape of one example (the batch axis of `x` dropped).
rng, r = random.split(rng)
inferred_model_state = model.init(r, np.zeros(x.shape[1:]))
# Optimizer state matching the freshly initialised parameters.
optimizer_state = optimizer.init(inferred_model_state)

In [9]:
# Ten training rounds; each `train` call performs `steps` optimizer updates.
# Accuracy is printed before every round and once more after the last one,
# giving eleven readings in total.
for i in range(10):
    print(accuracy(inferred_model_state))
    inferred_model_state, optimizer_state = train(inferred_model_state, optimizer_state)
print(accuracy(inferred_model_state))

0.20500001
0.74700004
0.887
0.92100006
0.933
0.94100004
0.952
0.957
0.96000004
0.9620001
0.96400005
