<a href="https://www.kaggle.com/code/yno3fm36xqnnc8/logistic-regression-the-jax-way?scriptVersionId=139646676" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

## Logistic Regression the JAX Way

In [1]:
import pandas as pd
import jax
import jax.numpy as jnp
from collections import namedtuple

This notebook demonstrates how to fit a logistic regression model with the JAX library. I'm sure there are other, better ways to do this, but this is a good start. As an example, we use the Titanic dataset from Kaggle.

In [2]:
test_data = pd.read_csv("/kaggle/input/titanic/test.csv")
train_data = pd.read_csv("/kaggle/input/titanic/train.csv")

pclasses = jnp.array(train_data['Pclass']).reshape((-1, 1))
survived = jnp.array(train_data['Survived']).reshape((-1, 1))

Our logistic regression model requires two parameters, the weight $w$ and the bias $b$. We will regress on a single feature: the passenger class number.

In [3]:
LogisticRegressionParams = namedtuple('LogisticRegressionParams', 'w b')

First we define the `sigmoid_predict` function. It takes the model parameters `params` and the regressors `x` for one passenger, and uses them to produce a single prediction. This particular model computes $f(x\cdot{w} + b)$, where $f(z) = \frac{1}{1+e^{-z}}$ is the logistic function.

In [4]:
@jax.jit
def sigmoid_predict(params: LogisticRegressionParams, x: jnp.ndarray) -> jnp.ndarray:
    z = x.dot(params.w) + params.b
    return jax.nn.sigmoid(z)
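As a quick sanity check, the sketch below evaluates the same formula for one passenger and compares `jax.nn.sigmoid` with the textbook definition of the logistic function. The `Params` tuple and the `logistic` helper are illustrative names that do not appear in the notebook.

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

# Hypothetical stand-in for the notebook's LogisticRegressionParams.
Params = namedtuple("Params", "w b")


def logistic(z):
    # Textbook definition: f(z) = 1 / (1 + exp(-z)).
    return 1.0 / (1.0 + jnp.exp(-z))


params = Params(w=jnp.array(1.0), b=jnp.array(1.0))  # the notebook's starting values
x = jnp.array([3.0])                                 # one passenger travelling in third class

z = x.dot(params.w) + params.b   # shape (1,); here 3 * 1 + 1 = 4
manual = logistic(z)
builtin = jax.nn.sigmoid(z)
print(z, manual, builtin)
```

With these starting values, a third-class passenger gets a survival probability close to 1, which is why the first training loss printed later is so large.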

Another model, presented below, uses the `softmax` function to produce a probability distribution over the two possible states, $0$ (dead) and $1$ (living). Specifically, the value $z = w^{\top}x + b$ is computed. In our case, this leaves us with a two-dimensional vector, which is fed to the softmax function, mapping $(u,v)$ to $\left(\frac{e^{u}}{e^{u} + e^{v}}, \frac{e^{v}}{e^{u} + e^{v}}\right)$.

Observe the `@jax.jit` decorator. This tells JAX to just-in-time compile our prediction function. Not all functions can be jitted; see [the JAX documentation on jitting](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html#why-can-t-we-just-jit-everything) for more.

In [5]:
@jax.jit
def softmax_predict(params: LogisticRegressionParams, x: jnp.ndarray) -> jnp.ndarray:
    z = params.w.transpose() @ x + params.b
    return jax.nn.softmax(z)
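To make the shapes concrete, here is a small sketch of the softmax computation for one passenger. The weight and bias values are made up for illustration; only the shapes, `(1, 2)` for `w` and `(2,)` for `b`, follow the notebook.

```python
import jax
import jax.numpy as jnp

w = jnp.array([[0.5, -0.5]])   # shape (1, 2): one feature, two classes (hypothetical values)
b = jnp.array([0.1, -0.1])     # shape (2,)
x = jnp.array([3.0])           # shape (1,): one third-class passenger

z = w.transpose() @ x + b      # shape (2,): one score per class
probs = jax.nn.softmax(z)

# The same mapping written out by hand: exp(z_i) / sum_j exp(z_j).
manual = jnp.exp(z) / jnp.exp(z).sum()
print(probs, manual)
```

The two entries of `probs` are non-negative and sum to one, so they can be read as the probabilities of states $0$ and $1$.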

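Below we use `jax.vmap` to vectorize the prediction functions, allowing them to take entire batches of data as input. The axes specification `(None, 0)` tells `vmap` to share `params` across the batch and to map over the leading axis of `x`.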

In [6]:
sigmoid_vpredict = jax.vmap(sigmoid_predict, in_axes=(None, 0))
softmax_vpredict = jax.vmap(softmax_predict, in_axes=(None, 0))
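The following sketch illustrates what the vectorized function does by comparing it with an explicit Python loop over the batch. The `predict` function and the toy batch are simplified stand-ins for the notebook's objects.

```python
import jax
import jax.numpy as jnp


def predict(params, x):
    # Single-example prediction: x has shape (1,), the result has shape (1,).
    w, b = params
    return jax.nn.sigmoid(x.dot(w) + b)


params = (jnp.array(1.0), jnp.array(1.0))
batch = jnp.array([[1.0], [2.0], [3.0]])   # shape (3, 1): three passengers, one feature each

# in_axes=(None, 0): do not map over params, map over axis 0 of the batch.
vpredict = jax.vmap(predict, in_axes=(None, 0))
batched = vpredict(params, batch)

# Reference result: call the single-example function once per row.
looped = jnp.stack([predict(params, x) for x in batch])
print(batched.shape, looped.shape)
```

Both results have one row per passenger; `vmap` simply avoids the Python-level loop.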

When we train the sigmoid model, we'll use the cross-entropy loss function. This can be thought of as an analogue of "distance" for probability distributions, so our goal is to make the prediction distribution (our model's output) as close as possible to the empirical distribution.

In [7]:
@jax.jit
def crossent(params: LogisticRegressionParams, features: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
    predictions = sigmoid_vpredict(params, features)
    a = labels * jnp.log(predictions)
    b = (1.0 - labels) * jnp.log(1.0 - predictions)
    return -jnp.mean(a + b)
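The sketch below works through the same loss on toy numbers. It also shows an optional variant that starts from the pre-sigmoid values $z$ and uses `jax.nn.log_sigmoid`, which avoids taking `log` of a probability that has rounded to exactly 0 or 1. The function names and data here are illustrative and not part of the notebook.

```python
import jax
import jax.numpy as jnp


def bce_from_probs(p, y):
    # Same formula as the notebook's loss: -mean(y log p + (1 - y) log(1 - p)).
    return -jnp.mean(y * jnp.log(p) + (1.0 - y) * jnp.log(1.0 - p))


def bce_from_logits(z, y):
    # log(sigmoid(z)) = log_sigmoid(z) and log(1 - sigmoid(z)) = log_sigmoid(-z).
    return -jnp.mean(y * jax.nn.log_sigmoid(z) + (1.0 - y) * jax.nn.log_sigmoid(-z))


z = jnp.array([-1.5, 0.0, 2.0])   # toy pre-sigmoid scores
y = jnp.array([0.0, 1.0, 1.0])    # toy labels

loss_probs = bce_from_probs(jax.nn.sigmoid(z), y)
loss_logits = bce_from_logits(z, y)

# A model that always predicts 0.5 has loss log(2), regardless of the labels.
uniform_loss = bce_from_probs(jnp.array([0.5, 0.5]), jnp.array([0.0, 1.0]))
print(loss_probs, loss_logits, uniform_loss)
```

The value $\log 2 \approx 0.693$ is a useful reference point: it is the loss of a model that knows nothing, and it matches the first loss printed when the softmax model starts from all-zero parameters.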

It's time to train the logistic model. This is accomplished by gradient descent. The gradient of the cross-entropy loss with respect to the parameters is computed with `jax.grad`. At each step, the gradient over the entire training set is computed and the model parameters $p = (w, b)$ are updated according to the rule
$$
p \leftarrow p - \eta \nabla_p \ell,
$$
where $\eta$ is the learning rate (set here to $0.01$) and $\ell$ is the loss function.

In [8]:
learning_rate = 1e-2
sigmoid_params = LogisticRegressionParams(jnp.array(1.0), jnp.array(1.0))
crossent_grad_fn = jax.grad(crossent)

for i in range(1_000):
    if i % 100 == 0:
        print(crossent(sigmoid_params, pclasses, survived))
    grads = crossent_grad_fn(sigmoid_params, pclasses, survived)
    sigmoid_params = LogisticRegressionParams(
        sigmoid_params.w - learning_rate * grads.w,
        sigmoid_params.b - learning_rate * grads.b,
    )


2.227012
0.7300361
0.625376
0.6208769
0.620122
0.6195703
0.61905324
0.61856127
0.6180927
0.61764634
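The update rule can also be written without naming each field, since JAX treats a `namedtuple` as a pytree. The sketch below is one possible variant, not the notebook's method: it uses `jax.value_and_grad` to get the loss and gradient in one call and `jax.tree_util.tree_map` to apply $p \leftarrow p - \eta \nabla_p \ell$ to every leaf. The `Params` tuple, `loss_fn`, `gd_step`, and the toy data are assumptions made for illustration.

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

# Hypothetical stand-in for the notebook's LogisticRegressionParams.
Params = namedtuple("Params", "w b")


def loss_fn(params, x, y):
    # Binary cross-entropy of a one-feature sigmoid model.
    p = jax.nn.sigmoid(x * params.w + params.b)
    return -jnp.mean(y * jnp.log(p) + (1.0 - y) * jnp.log(1.0 - p))


@jax.jit
def gd_step(params, x, y, learning_rate):
    # One full-batch gradient descent step; returns the loss before the update.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return new_params, loss


x = jnp.array([1.0, 2.0, 3.0, 3.0])   # toy class numbers
y = jnp.array([1.0, 1.0, 0.0, 0.0])   # toy survival labels
params = Params(jnp.array(1.0), jnp.array(1.0))

losses = []
for _ in range(200):
    params, loss = gd_step(params, x, y, 1e-2)
    losses.append(float(loss))
print(losses[0], losses[-1])
```

Because `tree_map` preserves the container type, `params` remains a `Params` tuple after every step, and the same step function would keep working if more fields were added.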


Now that the model is trained, let's see how well it performs.

In [9]:
preds = jnp.round(sigmoid_vpredict(sigmoid_params, pclasses))
accuracy = 1.0 - jnp.abs(preds - survived).mean()
print(f"accuracy is {100.0*accuracy:.2f}%")

accuracy is 67.90%


In this section, the softmax model is trained analogously to the sigmoid model.

When we train the softmax model, we'll use the `nll` function, which computes the negative log-likelihood of the data as a function of `params`. That is, it computes the negative logarithm of the probability that the model assigns to the observed labels. The `take_along_axis` function is used to index the log-probabilities, retrieving for each passenger the entry that corresponds to the observed outcome. Finally, the mean is taken across the whole batch.

In [10]:
@jax.jit
def nll(params: LogisticRegressionParams, regressors: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
    probs = softmax_vpredict(params, regressors)
    log_probs = jnp.log(probs)
    return -jnp.take_along_axis(log_probs, labels, 1).mean()
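To illustrate the indexing step, the sketch below applies `jnp.take_along_axis` to a small hand-made table of probabilities. The numbers are invented for the example; the label array has shape `(3, 1)`, the same layout as `survived` in the notebook.

```python
import jax.numpy as jnp

# Each row is a distribution over the two classes for one example (toy values).
log_probs = jnp.log(jnp.array([[0.9, 0.1],
                               [0.2, 0.8],
                               [0.6, 0.4]]))
labels = jnp.array([[0], [1], [1]])   # shape (3, 1): observed class per example

# For row i, pick column labels[i, 0]; the result keeps the shape (3, 1).
picked = jnp.take_along_axis(log_probs, labels, 1)
nll_value = -picked.mean()
print(picked.ravel(), nll_value)
```

Here the picked entries are the logs of 0.9, 0.8, and 0.4, so the loss is small when the model puts high probability on the class that was actually observed.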

In [11]:
learning_rate = 1e-2
softmax_params = LogisticRegressionParams(jnp.zeros([1, 2]), jnp.zeros([2]))
nll_grad_fn = jax.grad(nll)
for i in range(1_000):
    if i % 100 == 0:
        print(nll(softmax_params, pclasses, survived))
    grads = nll_grad_fn(softmax_params, pclasses, survived)
    softmax_params = LogisticRegressionParams(
        softmax_params.w - learning_rate * grads.w,
        softmax_params.b - learning_rate * grads.b,
    )

0.6931472
0.6381312
0.6353234
0.63285196
0.6306094
0.6285747
0.62672895
0.6250547
0.6235363
0.62215906


Now let's check the softmax model's accuracy.

In [12]:
preds = softmax_vpredict(softmax_params, pclasses).argmax(axis=1).reshape((-1, 1))
accuracy = 1.0 - abs(preds-survived).mean()
print(f"accuracy is {100.0*accuracy:.2f}%")

accuracy is 67.90%


Next, we compare the predictions that the two models make on the test set.

In [13]:
# Use a separate name so the training features in `pclasses` are not overwritten.
test_pclasses = jnp.array(test_data['Pclass']).reshape((-1, 1))

# Round the sigmoid probabilities to hard 0/1 labels, one value per passenger.
test_data['Survived'] = jnp.round(sigmoid_vpredict(sigmoid_params, test_pclasses)).ravel().astype(int).tolist()

# Pick the more probable class from the softmax output.
test_data['SurvivedSoftmax'] = softmax_vpredict(softmax_params, test_pclasses).argmax(axis=1).tolist()

test_data

| | PassengerId | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | Survived | SurvivedSoftmax |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 892 | 3 | Kelly, Mr. James | male | 34.5 | 0 | 0 | 330911 | 7.8292 | | Q | 0 | 0 |
| 1 | 893 | 3 | Wilkes, Mrs. James (Ellen Needs) | female | 47.0 | 1 | 0 | 363272 | 7.0000 | | S | 0 | 0 |
| 2 | 894 | 2 | Myles, Mr. Thomas Francis | male | 62.0 | 0 | 0 | 240276 | 9.6875 | | Q | 0 | 0 |
| 3 | 895 | 3 | Wirz, Mr. Albert | male | 27.0 | 0 | 0 | 315154 | 8.6625 | | S | 0 | 0 |
| 4 | 896 | 3 | Hirvonen, Mrs. Alexander (Helga E Lindqvist) | female | 22.0 | 1 | 1 | 3101298 | 12.2875 | | S | 0 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 413 | 1305 | 3 | Spector, Mr. Woolf | male | | 0 | 0 | A.5. 3236 | 8.0500 | | S | 0 | 0 |
| 414 | 1306 | 1 | Oliva y Ocana, Dona. Fermina | female | 39.0 | 0 | 0 | PC 17758 | 108.9000 | C105 | C | 1 | 1 |
| 415 | 1307 | 3 | Saether, Mr. Simon Sivertsen | male | 38.5 | 0 | 0 | SOTON/O.Q. 3101262 | 7.2500 | | S | 0 | 0 |
| 416 | 1308 | 3 | Ware, Mr. Frederick | male | | 0 | 0 | 359309 | 8.0500 | | S | 0 | 0 |


In [14]:
(test_data['Survived'] == test_data['SurvivedSoftmax']).all()

True

This shows that the two models make identical predictions for every passenger in the test set.
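Agreement between the two models is plausible because a two-class softmax can always be rewritten as a sigmoid: the probability of class $1$ under softmax scores $(u, v)$ equals the logistic function applied to $v - u$. The sketch below checks this numerically with made-up parameters. It shows only that the two model families can represent the same functions; it does not show that training ends at the same parameters.

```python
import jax
import jax.numpy as jnp

w = jnp.array([[0.3, -0.7]])   # hypothetical softmax weights, shape (1, 2)
b = jnp.array([-0.4, 1.2])     # hypothetical softmax biases, shape (2,)
xs = jnp.array([[1.0], [2.0], [3.0]])   # the three passenger classes


def softmax_p1(x):
    # Probability of class 1 under the two-class softmax model.
    z = w.transpose() @ x + b
    return jax.nn.softmax(z)[1]


def sigmoid_p1(x):
    # The same probability from a sigmoid model with the differences of the parameters.
    w_diff = w[0, 1] - w[0, 0]
    b_diff = b[1] - b[0]
    return jax.nn.sigmoid(x[0] * w_diff + b_diff)


p_softmax = jax.vmap(softmax_p1)(xs)
p_sigmoid = jax.vmap(sigmoid_p1)(xs)
print(p_softmax, p_sigmoid)
```

Since the single feature takes only three values, the hard predictions agree whenever both models place their decision boundary between the same two classes.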

Finally, save the sigmoid model's predictions to `/kaggle/working/submission.csv`.

In [15]:
test_data[['PassengerId', 'Survived']].set_index("PassengerId").to_csv("/kaggle/working/submission.csv")