<a href="https://www.kaggle.com/code/yno3fm36xqnnc8/logistic-regression-the-jax-way?scriptVersionId=139166041" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

## Logistic Regression the JAX Way

In [None]:
import pandas as pd
import jax
import jax.numpy as jnp
from collections import namedtuple

The goal of this notebook is to demonstrate how to do logistic regression with the JAX library. I'm sure there are other, better ways to do this, but this is a good start. For this example, let's use the Titanic dataset from Kaggle.

In [None]:
test_data = pd.read_csv("/kaggle/input/titanic/test.csv")
train_data = pd.read_csv("/kaggle/input/titanic/train.csv")

pclasses = train_data['Pclass']
pclasses = jnp.array(pclasses).reshape((-1, 1))

survived = train_data['Survived']
survived = jnp.array(survived).reshape((-1, 1))

Our logistic regression model has two parameters: the weight $w$, a $1 \times 2$ matrix, and the bias $b$, a vector of length $2$. Both are initialized to zero. We will regress on a single feature: passenger class number.

In [None]:
LogisticRegressionParams = namedtuple('LogisticRegressionParams', 'w b')
model_params = LogisticRegressionParams(jnp.zeros([1, 2]), jnp.zeros([2]))

First we define the `predict` function. It takes the model parameters `params` and the regressors `x` for one passenger, and uses them to create a single prediction. The model implemented here uses the `softmax` function to produce a probability distribution over the two possible outcomes, $0$ (did not survive) and $1$ (survived). Specifically, the value $$z = w^\top x + b$$ is computed. In our case, this leaves us with a two-dimensional vector which is fed to the softmax function, mapping $(u,v)$ to $(\frac{\exp{u}}{\exp{u} +\exp{v}}, \frac{\exp{v}}{\exp{u} +\exp{v}})$.

Observe the `@jax.jit` decorator. This tells JAX to just-in-time compile our prediction function. Not all functions can be jitted; see [this](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html#why-can-t-we-just-jit-everything) for more.

In [None]:
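@jax.jit
def predict(params: LogisticRegressionParams, x: jnp.ndarray) -> jnp.ndarray:
    # w has shape (1, 2) and x has shape (1,), so z is a vector of two logits.
    z = params.w.transpose() @ x + params.b
    # Softmax turns the logits into probabilities for classes 0 and 1.
    return jax.nn.softmax(z)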
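
With only two classes, the softmax above reduces to the familiar logistic (sigmoid) function applied to the difference of the two logits, which is why this model counts as logistic regression. The following is a minimal check of that identity. The logits are made-up values, not taken from the Titanic data.

```python
import jax
import jax.numpy as jnp

# Hypothetical logits (u, v) for one passenger; illustrative values only.
z = jnp.array([0.3, -1.2])

# Two-class softmax, as used in `predict`.
p_softmax = jax.nn.softmax(z)

# Classic logistic-regression form: P(class 1) = sigmoid(v - u).
p_sigmoid = jax.nn.sigmoid(z[1] - z[0])

print(p_softmax[1], p_sigmoid)
```

The two printed values should agree up to floating-point error.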

The `vpredict` function shows how we can vectorize functions with JAX, allowing us to compute predictions in batches.

In [None]:
@jax.jit
def vpredict(params, regressors):
    f = jax.vmap(lambda x: predict(params, x))
    return f(regressors)
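
Closing over `params` in a lambda is one way to vectorize only over the inputs. Another option is the `in_axes` argument of `jax.vmap`, where `None` marks an argument that is shared across the batch. The sketch below assumes a standalone helper `predict_one` and made-up parameter values. It is not the notebook's own `predict`.

```python
import jax
import jax.numpy as jnp

def predict_one(w, b, x):
    # w: (1, 2), b: (2,), x: (1,) -> class probabilities of shape (2,)
    return jax.nn.softmax(w.T @ x + b)

# in_axes=(None, None, 0): share w and b, map over the leading axis of x.
predict_batch = jax.vmap(predict_one, in_axes=(None, None, 0))

w = jnp.array([[0.5, -0.5]])            # illustrative values
b = jnp.zeros(2)
xs = jnp.array([[1.0], [2.0], [3.0]])   # three passengers, one feature each

probs = predict_batch(w, b, xs)
print(probs.shape)  # (3, 2)
```

Each row of `probs` is the distribution that `predict_one` would return for the corresponding row of `xs`.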

The `nll` function computes the mean negative log-likelihood of the data as a function of `params`. That is, it takes the probability the model assigns to each observed outcome, takes its logarithm, and negates the average over the batch, so a lower value means the model fits the data better. The `take_along_axis` function is used to index the predictions, retrieving the log-probability of the outcome that actually occurred for each passenger.

In [None]:
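@jax.jit
def nll(params: LogisticRegressionParams, regressors: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
    probs = vpredict(params, regressors)  # shape (n, 2)
    log_probs = jnp.log(probs)
    # Select the log-probability of each passenger's true label, then average.
    return -jnp.take_along_axis(log_probs, labels, 1).mean()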
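
To make the indexing step concrete, here is a small worked example of `take_along_axis` on hand-written probabilities. The numbers are invented for illustration.

```python
import jax.numpy as jnp

# Predicted probabilities for three passengers (columns: class 0, class 1).
log_probs = jnp.log(jnp.array([[0.9, 0.1],
                               [0.2, 0.8],
                               [0.6, 0.4]]))
# True labels, shaped (n, 1) like `survived` in this notebook.
labels = jnp.array([[0], [1], [1]])

# For each row, pick the column given by that row's label.
picked = jnp.take_along_axis(log_probs, labels, axis=1)  # shape (3, 1)
loss = -picked.mean()
print(picked.ravel(), loss)
```

Here `picked` holds $\log 0.9$, $\log 0.8$ and $\log 0.4$, the log-probabilities of the true outcomes, and `loss` is the negative of their mean.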

Now it's time to train the model. This is accomplished by gradient descent. The gradient of the negative log-likelihood function `nll` with respect to the parameters is obtained with `jax.grad`. At each step, the gradient over the entire training set is computed and each parameter $p$ (here $w$ and $b$) is updated according to the rule
$$
p \leftarrow p - \eta \nabla_p \ell,
$$
where $\ell$ is the negative log-likelihood and $\eta$ is the learning rate (set here to $0.01$).

In [None]:
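learning_rate = 1e-2
# jax.grad differentiates nll with respect to its first argument (the parameters).
loss_grad_fn = jax.grad(nll)
for i in range(1_000):
    if i % 100 == 0:
        # Report the training loss every 100 steps.
        print(nll(model_params, pclasses, survived))
    # grads has the same namedtuple structure as model_params.
    grads = loss_grad_fn(model_params, pclasses, survived)
    model_params = LogisticRegressionParams(model_params.w - learning_rate * grads.w, model_params.b - learning_rate * grads.b)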
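
The loop above updates each field by hand. Because JAX treats namedtuples as pytrees, the same update can be written once for all parameters with `jax.tree_util.tree_map`, and `jax.value_and_grad` returns the loss together with the gradient. The sketch below is self-contained and rests on several assumptions: `Params`, `toy_nll` and `sgd_step` are hypothetical names, the data is a four-row toy set, and `jax.nn.log_softmax` is swapped in for `log(softmax(...))` as a numerically safer variant.

```python
import jax
import jax.numpy as jnp
from collections import namedtuple

# Hypothetical stand-in for LogisticRegressionParams.
Params = namedtuple("Params", "w b")

def toy_nll(params, x, y):
    logits = x @ params.w + params.b          # (n, 1) @ (1, 2) + (2,) -> (n, 2)
    log_probs = jax.nn.log_softmax(logits)    # stable log of the softmax
    return -jnp.take_along_axis(log_probs, y, axis=1).mean()

@jax.jit
def sgd_step(params, x, y, lr):
    # Loss and gradient in one pass; grads has the same structure as params.
    loss, grads = jax.value_and_grad(toy_nll)(params, x, y)
    # Apply p <- p - lr * g to every leaf of the parameter pytree.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Toy data: class 1 passengers labelled 1, class 3 passengers labelled 0.
x = jnp.array([[1.0], [1.0], [3.0], [3.0]])
y = jnp.array([[1], [1], [0], [0]])
params = Params(jnp.zeros((1, 2)), jnp.zeros(2))

params, first_loss = sgd_step(params, x, y, 0.1)
for _ in range(100):
    params, loss = sgd_step(params, x, y, 0.1)
print(first_loss, loss)
```

Jitting the whole step, rather than only `nll`, compiles the gradient computation as well.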

Now that the model is trained, let's see how well it performs.

In [None]:
f = jax.vmap(lambda x: predict(model_params, x))
preds = f(pclasses).argmax(axis=1).reshape((-1, 1))
accuracy = 1.0 - abs(preds-survived).mean()
print(f"accuracy is {100.0*accuracy:.2f}%")
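
For binary labels, `1 - abs(preds - labels).mean()` equals the fraction of predictions that match the labels. A useful reference point for any accuracy figure is the score of always predicting the more common class. The sketch below shows both on invented arrays. The name `majority_baseline` is an assumption made for illustration.

```python
import jax.numpy as jnp

# Invented predictions and labels, shaped (n, 1) as in the notebook.
preds = jnp.array([[0], [1], [0], [0]])
labels = jnp.array([[0], [1], [1], [0]])

# Fraction of exact matches; same value as 1 - |preds - labels|.mean() for 0/1 data.
accuracy = (preds == labels).mean()

# Accuracy obtained by always predicting the more frequent class.
positive_rate = labels.mean()
majority_baseline = jnp.maximum(positive_rate, 1.0 - positive_rate)

print(accuracy, majority_baseline)
```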

Finally, we generate predictions for the test set and save them to `/kaggle/working/submission.csv`.

In [None]:
pclasses = test_data['Pclass']
pclasses = jnp.array(pclasses).reshape((-1, 1))
f = jax.vmap(lambda x: predict(model_params, x))
test_preds = f(pclasses).argmax(axis=1).reshape((-1, 1))
test_data['Survived'] = test_preds
test_data['Survived'] = test_data['Survived'].apply(lambda x: int(x))
test_data[['PassengerId', 'Survived']].set_index("PassengerId").to_csv("/kaggle/working/submission.csv")
test_data