# Tutorial #2: Logistic Regression

In this tutorial, we'll explain logistic regression and use it for classification of a few simple datasets. We'll begin by performing logistic regression analytically using a synthetic dataset for a single scalar variable and binary outcomes, then we'll use `scikit-learn` and `flax` to perform logistic regression on a multivariate dataset with binary outcomes.

Assume that the output variable $y \in \{0, 1\}$ is binary, while the input variable $\boldsymbol{x} \in \mathbb{R}^d$ is a vector of $d$ features. The assumption in a logistic regression model is that $y=1$ with probability $\theta$ and $y=0$ with probability $1-\theta$. Thus, $$p(y|\theta) = \theta^y(1-\theta)^{1-y}.$$

The probability $\theta$ must be between 0 and 1. To ensure this, we set $\theta = \sigma(z)$ where $\sigma$ is the logistic function $$\sigma(z) = \frac{e^z}{1+e^z} = \frac{1}{1+e^{-z}}.$$ Note that $\sigma(z \rightarrow \infty) = 1$ and $\sigma(z \rightarrow -\infty) = 0$. While non-linear models can be used for $z$, the simplest assumption is that $z$ depends linearly on the inputs $\boldsymbol{x}$, so that $z = \boldsymbol{w}^T \boldsymbol{x}$ (a bias term $b$ can be absorbed into $\boldsymbol{w}$ by appending a constant feature of 1 to $\boldsymbol{x}$). This gives $$p(y|\boldsymbol{x}, \boldsymbol{w}) = \sigma(\boldsymbol{w}^T \boldsymbol{x})^y(1-\sigma(\boldsymbol{w}^T \boldsymbol{x}))^{1-y}.$$

We then maximize the log probability of the data, which is equivalent to minimizing a loss function given by the negative log probability of the data. For a dataset with $N$ examples $\{\boldsymbol{x}_{i}, y_i\}_{i=1}^N$, the logistic regression weights are thus $$\boldsymbol{w}^* = \arg \min_{\boldsymbol{w}} \sum_{i=1}^N \left[ -y_i \log{\sigma(\boldsymbol{w}^T\boldsymbol{x}_i)} -(1-y_i)\log{(1-\sigma(\boldsymbol{w}^T\boldsymbol{x}_i))} \right].$$

We can generalize logistic regression from the binary classification case to the $K$-class classification case by replacing the logistic function with the softmax function $$\sigma(\boldsymbol{z})_i = \frac{e^{z_i}}{\sum_{j=1}^K e^{z_j}},$$ where $\boldsymbol{z} \in \mathbb{R}^K$ holds one score per class.
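As a quick numerical illustration of these definitions, the sketch below evaluates the logistic function and the per-example negative log-likelihood for one input, and checks that a two-class softmax over the scores $(z, 0)$ gives the same probability as $\sigma(z)$. The weights and input are made-up illustrative values, and the helper name `binary_nll` is an assumption introduced for this example.

```python
import jax
import jax.numpy as jnp

def logistic(z):
    # Logistic (sigmoid) function: maps any real z into (0, 1).
    return 1.0 / (1.0 + jnp.exp(-z))

def binary_nll(w, x, y):
    # Negative log-likelihood of one example under p(y | x, w).
    theta = logistic(jnp.dot(w, x))
    return -y * jnp.log(theta) - (1 - y) * jnp.log(1 - theta)

# Illustrative values (assumptions, not taken from any dataset).
w = jnp.array([0.5, -1.0])
x = jnp.array([2.0, -1.0])
z = jnp.dot(w, x)               # 0.5 * 2 + (-1) * (-1) = 2

theta = logistic(z)             # p(y = 1 | x, w)
loss_y1 = binary_nll(w, x, 1)   # equals -log(theta)
loss_y0 = binary_nll(w, x, 0)   # equals -log(1 - theta)

# Softmax over the two scores (z, 0): the first entry is e^z / (e^z + 1).
softmax_probs = jax.nn.softmax(jnp.array([z, 0.0]))

print(theta, softmax_probs[0], loss_y1, loss_y0)
```

Since $z > 0$ here, the model favors $y = 1$, so the loss for the label $y = 1$ is smaller than the loss for $y = 0$.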

Useful references on Logistic Regression:
* [COS324](https://www.cs.princeton.edu/courses/archive/spring19/cos324/) Lecture Notes, [Logistic Regression](https://www.cs.princeton.edu/courses/archive/spring19/cos324/files/logistic-regression.pdf)
* [Probabilistic Machine Learning Book](https://probml.github.io/pml-book/book1.html) (Kevin Murphy), Chapter 10 

In [None]:
import jax.numpy as jnp
import jax.random as random
import jax
import matplotlib.pyplot as plt

### 2.1: Analytic Logistic Regression using `jax.numpy` and synthetic data

We'll first generate some synthetic data. We'll have two classes which are not linearly separable.

In [None]:
key = random.key(0)
N_data = 40

def class_1(key, N):
    return -2.5 + random.normal(key, (N,))

def class_2(key, N):
    return 0.5 + random.normal(key, (N,))

key1, key2 = random.split(key)

x_1 = class_1(key1, N_data//2)
x_2 = class_2(key2, N_data//2)

y_1 = jnp.zeros(x_1.shape)
y_2 = jnp.ones(x_2.shape)

plt.scatter(x_1, y_1, color='blue', marker='x')
plt.scatter(x_2, y_2, color='red', marker='x')
plt.show()

We'll then assume that our data is described by a logistic regression model, $p(y = 1 | x) = \sigma(z)$ with $z = w x + b$, and fit the scalar parameters $w$ and $b$ by gradient descent on the mean negative log-likelihood.

In [None]:
X = jnp.concatenate([x_1, x_2])
Y = jnp.concatenate([y_1, y_2])

def logistic(z):
    return 1 / (1 + jnp.exp(-z))

def loss_function(x, y, w, b):
    z = x * w + b
    return -y * jnp.log(logistic(z)) - (1-y) * jnp.log(1 - logistic(z))

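def mean_loss(w, b):
    # Average the per-example loss over the whole dataset (X, Y).
    per_example = jax.vmap(loss_function, in_axes=(0, 0, None, None))(X, Y, w, b)
    return jnp.mean(per_example)

# Returns the loss and its gradients with respect to (w, b).
loss_grad_fn = jax.jit(jax.value_and_grad(mean_loss, argnums=(0, 1)))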

In [None]:
key, subkey = random.split(key)
w_init = random.normal(subkey)
b_init = 0.0
N_train = 100000
lr = 1e-3
losses = []

w = w_init
b = b_init
for _ in range(N_train):
    loss, grads = loss_grad_fn(w, b)
    losses.append(loss)
    w = w - lr * grads[0]
    b = b - lr * grads[1]
print(w)
print(b)

In [None]:
plt.plot(losses)
plt.show()

In [None]:
plt.scatter(x_1, y_1, color='blue', marker='x',label='class 0')
plt.scatter(x_2, y_2, color='red', marker='x', label='class 1')
x_plot = jnp.linspace(-5.0, 3.0, 1000)
plt.plot(x_plot, logistic(x_plot * w + b), color='green',label='Logistic Classifier')
plt.legend()
plt.show()
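For the scalar model $z = wx + b$, the gradients of the mean negative log-likelihood have a closed form: $\partial L/\partial w = \frac{1}{N}\sum_i (\sigma(z_i) - y_i)\,x_i$ and $\partial L/\partial b = \frac{1}{N}\sum_i (\sigma(z_i) - y_i)$. The sketch below compares these expressions against `jax.grad` on a small made-up dataset. The names `x_demo`, `y_demo`, `w_demo`, `b_demo`, and `analytic_grads` are assumptions introduced only for this check.

```python
import jax
import jax.numpy as jnp

def logistic(z):
    return 1.0 / (1.0 + jnp.exp(-z))

def mean_loss(w, b, x, y):
    # Mean negative log-likelihood of the scalar logistic model.
    theta = logistic(w * x + b)
    return jnp.mean(-y * jnp.log(theta) - (1 - y) * jnp.log(1 - theta))

def analytic_grads(w, b, x, y):
    # Closed-form gradients: the residual sigma(z) - y, weighted by x for w.
    residual = logistic(w * x + b) - y
    return jnp.mean(residual * x), jnp.mean(residual)

# Small illustrative dataset with overlapping classes (assumed values).
x_demo = jnp.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0])
y_demo = jnp.array([0.0, 0.0, 1.0, 0.0, 1.0, 1.0])
w_demo, b_demo = 0.7, -0.2

auto_dw, auto_db = jax.grad(mean_loss, argnums=(0, 1))(w_demo, b_demo, x_demo, y_demo)
hand_dw, hand_db = analytic_grads(w_demo, b_demo, x_demo, y_demo)

print(auto_dw, hand_dw)
print(auto_db, hand_db)
```

Once $w$ and $b$ are fitted, the classifier predicts $p(y=1|x) = 0.5$ where $wx + b = 0$, so the decision boundary sits at $x = -b/w$.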

### 2.2: Scikit-learn for multivariate dataset

We can read the data (a breast cancer diagnostic dataset) into a pandas dataframe. We'll then separate the data into training and testing splits, with 80% of the data in the training set and 20% in the testing set. Since our data consists of measurements from 568 individuals, with no temporal ordering, we can split it into training and testing sets by random sampling without introducing data leakage.

In [None]:
import pandas as pd

df = pd.read_csv('datasets/breast_cancer_wisconsin/wdbc.data')
dataset = df.to_numpy()
X = jnp.asarray(dataset[:,2:].astype(float))
y = jnp.asarray((dataset[:,1] == 'B').astype(int))

key = random.key(0)
key, subkey = random.split(key)
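# Draw one permutation of the row indices and apply it to both X and y,
# so that features and labels are guaranteed to stay aligned.
perm = random.permutation(subkey, X.shape[0])
X_shuffled = X[perm]
y_shuffled = y[perm]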

N_div = int(0.8 * X.shape[0])
X_train = X_shuffled[:N_div]
y_train = y_shuffled[:N_div]
X_test = X_shuffled[N_div:]
y_test = y_shuffled[N_div:]

print(X_train.shape)
print(y_train.shape)
print(X_test.shape)
print(y_test.shape)

Next we can use sklearn to perform logistic regression. With the `.fit()` function we can train a model, and with the `.score()` function we can evaluate performance on the testing dataset.

In [None]:
from sklearn.linear_model import LogisticRegression

model = LogisticRegression(penalty='l2', max_iter=10000)
trained_model = model.fit(X_train, y_train)

print(trained_model.coef_)
print(trained_model.intercept_)

By the way, we could have also imported the dataset from the `sklearn` library, using `from sklearn.datasets import load_breast_cancer` and
`data = load_breast_cancer()` with `X = data.data` and `y = data.target`.

We'll now evaluate model performance using the test dataset.

In [None]:
trained_model.score(X_test, y_test)

We get about 93.8% accuracy.
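To connect the fitted `sklearn` model back to the formulas above, the sketch below rebuilds $p(y=1|\boldsymbol{x}) = \sigma(\boldsymbol{w}^T\boldsymbol{x} + b)$ by hand from `coef_` and `intercept_` and compares it with `predict_proba`. It assumes the copy of the dataset bundled with `sklearn` (`load_breast_cancer`) and, for brevity, fits and evaluates on the full dataset rather than on a train/test split, so the accuracy it prints is not comparable to the test accuracy above. The names `X_demo`, `y_demo`, and `clf` are illustrative.

```python
import numpy as np
from scipy.special import expit  # numerically stable logistic function
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression

data = load_breast_cancer()
X_demo, y_demo = data.data, data.target

clf = LogisticRegression(max_iter=10000).fit(X_demo, y_demo)

# z = w^T x + b for every example, built from the fitted parameters.
z = X_demo @ clf.coef_[0] + clf.intercept_[0]
manual_prob = expit(z)                          # sigma(z) = p(y = 1 | x)
sklearn_prob = clf.predict_proba(X_demo)[:, 1]  # same quantity from sklearn

# sigma(z) > 0.5 exactly when z > 0, so thresholding z reproduces .predict().
manual_accuracy = np.mean((z > 0) == y_demo)

print(np.max(np.abs(manual_prob - sklearn_prob)))
print(manual_accuracy, clf.score(X_demo, y_demo))
```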

### 2.3: `flax` for multivariate binary classification with logistic regression

We'll first write a logistic regression model in `flax`.

In [None]:
from flax import nnx
import optax

class LogisticRegression(nnx.Module):
    def __init__(self, din: int, rngs: nnx.Rngs):
        self.linear = nnx.Linear(din, 1, rngs=rngs)

    @staticmethod
    def logistic(z):
        return 1 / (1 + jnp.exp(-z))

    def __call__(self, x):
        # Map the linear score z to a probability in (0, 1).
        return self.logistic(self.linear(x))

In [None]:
rngs = nnx.Rngs(0)
model = LogisticRegression(X.shape[1], rngs=rngs)
optimizer = nnx.Optimizer(model, optax.sgd(1e-3))

In [None]:
@nnx.jit
def train_step(model, optimizer, X_train, y_train):
    def loss_fn(model):
        theta = nnx.vmap(model)(X_train)[:,0]
        # Mean binary cross-entropy (negative log-likelihood) over the batch.
        return jnp.mean(- y_train * jnp.log(theta) - (1 - y_train) * jnp.log(1 - theta))

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)

    return loss

In [None]:
loss = train_step(model, optimizer, X_train, y_train)
print(loss)

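Oh no! Our loss function is giving us `NaN`s. The culprit is the `logistic` function: the unnormalized features produce values of `z` with very large magnitude, for which `logistic(z)` saturates to exactly 0 or 1 in floating point. The loss then contains terms like `0 * log(0)`, i.e. `0 * (-inf)`, which evaluate to `NaN`. We'll have to find a way to eliminate `NaN`s from the loss function.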
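The failure can be reproduced in isolation. The sketch below evaluates the same per-example loss at a single large score and contrasts it with one possible rewrite that works directly on $z$ through `jax.nn.log_sigmoid`, using the identities $\log \sigma(z) = \texttt{log\_sigmoid}(z)$ and $\log(1 - \sigma(z)) = \texttt{log\_sigmoid}(-z)$. The function names `naive_bce` and `stable_bce` and the test values are assumptions made for this example; this is not the approach the tutorial takes next.

```python
import jax
import jax.numpy as jnp

def logistic(z):
    return 1.0 / (1.0 + jnp.exp(-z))

def naive_bce(z, y):
    # Loss written in terms of the probability theta = sigma(z).
    theta = logistic(z)
    return -y * jnp.log(theta) - (1 - y) * jnp.log(1 - theta)

def stable_bce(z, y):
    # Same loss written in terms of the score z, avoiding log(0).
    return -y * jax.nn.log_sigmoid(z) - (1 - y) * jax.nn.log_sigmoid(-z)

z_large = jnp.array(200.0)  # a score far outside the range where sigma is unsaturated
y_label = jnp.array(1.0)

# sigma(200) rounds to exactly 1, so the (1 - y) term becomes 0 * log(0).
naive_value = naive_bce(z_large, y_label)
stable_value = stable_bce(z_large, y_label)

print(naive_value)   # NaN
print(stable_value)  # finite, essentially zero loss for a confident correct prediction
```

With this formulation a model would return the score $z$ rather than the probability, and the sigmoid would only be applied when probabilities are needed for prediction.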

#### 2.3.1: Normalize data

While we could use clever numerical tricks to prevent `NaN` and `inf` within the loss function, a simpler approach is to normalize the dataset. For each feature (in both the training and testing sets), we'll subtract the mean of the training data and divide by the standard deviation of the training data. Using only training statistics keeps information about the test set out of the preprocessing.

In [None]:
means = jnp.mean(X_train, axis=0)
stds = jnp.std(X_train, axis=0)

X_train_normalized = (X_train - means) / stds
X_test_normalized = (X_test - means) / stds
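As a sanity check on this kind of normalization, the sketch below standardizes synthetic features with very different scales and confirms that the training features end up with roughly zero mean and unit standard deviation. The helper `standardize`, its `eps` guard against a zero standard deviation, and the synthetic arrays are all assumptions made for illustration.

```python
import jax.numpy as jnp
import jax.random as random

def standardize(X_fit, X_apply, eps=1e-8):
    # Statistics come from X_fit only; eps (an assumption) guards against zero std.
    means = jnp.mean(X_fit, axis=0)
    stds = jnp.std(X_fit, axis=0)
    return (X_apply - means) / (stds + eps)

key_train, key_test = random.split(random.key(0))

# Three synthetic features with the same offset but very different scales.
scales = jnp.array([1.0, 100.0, 0.01])
X_train_demo = 5.0 + scales * random.normal(key_train, (200, 3))
X_test_demo = 5.0 + scales * random.normal(key_test, (50, 3))

X_train_std = standardize(X_train_demo, X_train_demo)
X_test_std = standardize(X_train_demo, X_test_demo)  # reuses training statistics

print(jnp.mean(X_train_std, axis=0), jnp.std(X_train_std, axis=0))
print(jnp.mean(X_test_std, axis=0), jnp.std(X_test_std, axis=0))
```

The test features are only approximately standardized, since their own mean and standard deviation differ slightly from the training statistics.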

Now let's try computing the loss function with the normalized data.

In [None]:
rngs = nnx.Rngs(0)
model = LogisticRegression(X.shape[1], rngs=rngs)
optimizer = nnx.Optimizer(model, optax.sgd(1e-3))

In [None]:
loss = train_step(model, optimizer, X_train_normalized, y_train)
print(loss)
loss = train_step(model, optimizer, X_train_normalized, y_train)
print(loss)

In [None]:
losses = []
N_train = 50000
for _ in range(N_train):
    loss = train_step(model, optimizer, X_train_normalized, y_train)
    losses.append(loss)

In [None]:
plt.plot(losses)
plt.show()

We'll now evaluate the performance of the trained model.

In [None]:
theta_test = nnx.vmap(model)(X_test_normalized)[:,0]
y_test_eval = (theta_test > 0.5).astype(int)
accuracy = jnp.mean((y_test == y_test_eval).astype(int))
print(accuracy)

We get 96.5% accuracy on the test dataset.