In [None]:
%pylab inline
plt.style.use("bmh")
plt.rcParams["figure.figsize"] = (6,6)

To install the CPU-only build of JAX, run (the quotes keep shells such as zsh from expanding the brackets):

```
pip3 install "jax[cpu]"
```

For GPU support, follow the [JAX installation guide](https://github.com/google/jax#installation).

In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [None]:
import numpy as np
from sklearn.datasets import make_blobs

# Creating the data

In [None]:
# Two Gaussian blobs with labels 0/1. A fixed random_state makes the data (and
# every result below) reproducible on Restart & Run All.
X, y = make_blobs(1000, centers=[[-3, -3], [0, 0]], cluster_std=1.25, random_state=0)

In [None]:
# Scatter the raw data coloured by its true label (0 = blue, 1 = red).
scatter_kwargs = dict(alpha=0.6, edgecolor='k', cmap=plt.cm.coolwarm, vmin=0, vmax=1)
fig, ax = plt.subplots(figsize=(7, 6))
points = ax.scatter(X[:, 0], X[:, 1], c=y, **scatter_kwargs)
fig.colorbar(points, ax=ax)
ax.set_xlabel('$x_0$', fontsize=14)
ax.set_ylabel('$x_1$', fontsize=14)
ax.set_title("Actual targets", fontsize=12)
fig.tight_layout()

In [None]:
# JAX copies of the NumPy data for use with jnp / jit / vmap / grad.
Xj, yj = jnp.asarray(X), jnp.asarray(y)

# Basic operations

Random number generation (JAX uses explicit PRNG keys rather than a global random state):

In [None]:
key = random.PRNGKey(0)

# JAX keys are consumed, not advanced: reusing one key for both draws gives
# correlated values. Split off an independent subkey for each parameter.
key, w_key, b_key = random.split(key, 3)
W = random.normal(w_key, (2, 1))
b = random.normal(b_key, (1, 1))

In [None]:
# Display the initial weights and bias together.
(W, b)

Linear algebra:

In [None]:
# Matrix product of the (n, 2) data with the (2, 1) weights.
Xj @ W

Logistic regression functions:

In [None]:
def sigmoid(a):
    """Logistic function 1 / (1 + exp(-a)), applied elementwise."""
    # NOTE(review): exp(-a) overflows to inf for very negative `a`; the value
    # still rounds to 0, but derivatives there may become NaN.
    denominator = 1 + jnp.exp(-a)
    return 1 / denominator

def regressor(x, w, b):
    """Batched logistic regression: sigmoid(x @ w + b)."""
    logits = jnp.dot(x, w) + b
    return sigmoid(logits)

In [None]:
y_pred = regressor(Xj, W, b)

# Left panel: predictions of the randomly initialised model, with its weight
# vector drawn from the origin. Right panel: ground-truth labels.
plt.figure(figsize=(14, 6))
panels = [(y_pred.flatten(), "Predicted targets"), (y, "Actual targets")]
for panel_idx, (colors, panel_title) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_idx)
    plt.scatter(X[:, 0], X[:, 1], c=colors, alpha=0.6, edgecolor='k',
                cmap=plt.cm.coolwarm, vmin=0, vmax=1)
    plt.colorbar()
    plt.xlabel('$x_0$', fontsize=14)
    plt.ylabel('$x_1$', fontsize=14)
    if panel_idx == 1:
        plt.plot([0, W[0, 0].item()], [0, W[1, 0].item()], "-",
                 c="forestgreen", linewidth=4)
    plt.title(panel_title, fontsize=12)
plt.tight_layout()

JIT compilation:

In [None]:
# Wrap the batched regressor with jit; the call below is its first use, which
# triggers tracing and XLA compilation.
regressor_jit = jit(regressor)
regressor_jit(Xj, W, b)

Calls are dispatched asynchronously by default, so the timings below call `.block_until_ready()` to wait for the actual result. This matters little on CPU but is essential on GPU:

In [None]:
# block_until_ready() waits for the asynchronously dispatched result, so the
# timing covers the computation itself, not just the dispatch.
%timeit regressor_jit(Xj, W, b).block_until_ready()

In [None]:
# Same computation without jit, for comparison with the compiled timing.
%timeit regressor(Xj, W, b).block_until_ready()

In [None]:
%timeit np.dot(X, W.to_py())

# Vectorization

Wouldn't it be great to have automatic vectorization?

In [None]:
# Flat parameters for the single-example model: a length-2 weight vector and a
# scalar bias. Split the key so the two draws are independent.
key, w_key, b_key = random.split(key, 3)
Wf = random.normal(w_key, (2,))
bf = random.normal(b_key)

In [None]:
# Display the flat weight vector and scalar bias.
(Wf, bf)

Single example operation:

In [None]:
# Linear score of the first example: w . x + b.
jnp.sum(Xj[0] * Wf) + bf

Logistic regression function applied to a single example:

In [None]:
def regressor(x, w, b):
    """Logistic regression for a single example.

    Intentionally replaces the batched `regressor` defined earlier, so that
    batching can be added back with `vmap`. The earlier `sigmoid` is reused;
    redefining it here would only shadow an identical copy.

    Args:
        x: features of one example (e.g. ``Xj[0]``).
        w: weight vector, expected to have the same length as ``x``.
        b: scalar bias.

    Returns:
        sigmoid(sum(x * w) + b); a scalar when ``b`` is a scalar.
    """
    return sigmoid(jnp.sum(x * w) + b)

In [None]:
# Predicted probability for the first example only.
regressor(x=Xj[0], w=Wf, b=bf)

Vectorization is done with `vmap`:

In [None]:
# Map over the rows of x (axis 0); w and b are shared across the whole batch.
batch_axes = (0, None, None)
regressor_batch = vmap(regressor, in_axes=batch_axes)

In [None]:
# Apply the single-example regressor to every row of Xj in one call.
# (Arguments stay positional: in_axes only applies to positional arguments.)
regressor_batch(Xj, Wf, bf)

In [None]:
# vmap without jit; -n 10 -r 3 limits the number of runs (presumably because
# the un-jitted version is slow).
%timeit -n 10 -r 3 regressor_batch(Xj, Wf, bf).block_until_ready()

`jit` composes with `vmap`:

In [None]:
# Transformations compose: jit-compile the vmapped function. The first call
# below triggers tracing and compilation.
regressor_batch_jit = jit(regressor_batch)
regressor_batch_jit(Xj, Wf, bf)

In [None]:
# Timing of the jit(vmap(...)) version.
%timeit regressor_batch_jit(Xj, Wf, bf).block_until_ready()

In [None]:
# Un-jitted vmap again, now with %timeit's default run counts, for comparison.
%timeit regressor_batch(Xj, Wf, bf).block_until_ready()

# Autodiff

Autodiff in JAX is functional and explicit: `grad` takes a function and returns a new function that computes its derivative:

In [None]:
# grad(sigmoid) is a new function returning d sigmoid / da at a scalar input.
sigmoid_grad = grad(sigmoid)

In [None]:
# Derivative and value of the sigmoid at a = 0.1.
(sigmoid_grad(0.1), sigmoid(0.1))

In [None]:
# Autodiff result next to the analytical derivative sigmoid(a) * (1 - sigmoid(a)).
(sigmoid_grad(0.1), sigmoid(0.1) * (1 - sigmoid(0.1)))

In [None]:
def f(x, y):
    """Toy two-argument function f(x, y) = sigmoid(x) * sigmoid(2 * y)."""
    sig_x = sigmoid(x)
    sig_2y = sigmoid(2 * y)
    return sig_x * sig_2y

In [None]:
XVAL = YVAL = 0.1

# argnums=(0, 1): differentiate w.r.t. both arguments; the result is the tuple
# (df/dx, df/dy) evaluated at (XVAL, YVAL).
f_grad = grad(f, argnums=(0, 1))
f_grad(XVAL, YVAL)

In [None]:
# Analytical df/dx = sigmoid(x) * sigmoid(2y) * (1 - sigmoid(x)); compare with f_grad(...)[0].
sigmoid(XVAL) * sigmoid(2 * YVAL) * (1 - sigmoid(XVAL))

In [None]:
# Analytical df/dy = 2 * sigmoid(x) * sigmoid(2y) * (1 - sigmoid(2y)); compare with f_grad(...)[1].
2 * sigmoid(XVAL) * sigmoid(2 * YVAL) * (1 - sigmoid(2 * YVAL))

In [None]:
@jit
def loss(X, Y, w, b, eps=1e-7):
    """Mean binary cross-entropy of the vmapped single-example regressor.

    Args:
        X: feature matrix, one example per row.
        Y: 0/1 targets, one per row of ``X``.
        w: weight vector passed to ``regressor_batch``.
        b: scalar bias passed to ``regressor_batch``.
        eps: predictions are clipped to ``[eps, 1 - eps]`` before taking logs.

    Returns:
        Scalar loss, suitable for ``jax.grad``.
    """
    y_pred = regressor_batch(X, w, b)
    # A saturated sigmoid can return exactly 0.0 or 1.0; log(0) = -inf would
    # turn the loss into inf/nan and poison the gradients.
    y_pred = jnp.clip(y_pred, eps, 1 - eps)
    return -jnp.mean(Y * jnp.log(y_pred) + (1 - Y) * jnp.log(1 - y_pred))

In [None]:
# Gradient of the loss w.r.t. positional arguments 2 and 3 (w and b), returned
# as a tuple (dL/dw, dL/db); wrapped in jit for speed in the training loop.
loss_grad = jit(grad(loss, argnums=(2,3)))

In [None]:
# Fresh parameters for training. Split the key so weights and bias come from
# independent random streams instead of reusing one key for both.
key, w_key, b_key = random.split(key, 3)
Wf = random.normal(w_key, (2,))
bf = random.normal(b_key)

In [None]:
# Gradients (dL/dw, dL/db) at the freshly initialised parameters.
loss_grad(Xj, yj, Wf, bf)

In [None]:
EPOCHS = 1000
LR = 1e-1        # gradient-descent step size
DELTA = 0.00001  # stop once one epoch improves the loss by less than this
loss_history = []

for i in range(EPOCHS):
    current_loss = loss(Xj, yj, Wf, bf)
    w_grad, b_grad = loss_grad(Xj, yj, Wf, bf)
    # float() pulls the scalar back to the host (DeviceArray.to_py() was removed from JAX).
    loss_history.append(float(current_loss))

    # Plain gradient-descent update.
    Wf = Wf - w_grad * LR
    bf = bf - b_grad * LR

    if i % 20 == 0:
        print(f"Epoch {i}: loss = {loss_history[-1]}")

    # Early stopping needs two recorded losses. An explicit length check
    # replaces the bare `except`, which would also have hidden real errors.
    if len(loss_history) >= 2 and loss_history[-2] - loss_history[-1] < DELTA:
        break

In [None]:
y_pred = regressor_batch(Xj, Wf, bf)

plt.figure(figsize=(14,6))

# Left: predictions of the trained model plus its learned weight vector.
plt.subplot(1, 2, 1)
plt.scatter(X[:, 0], X[:, 1], c=y_pred.flatten(), alpha=0.6, edgecolor='k',
            cmap=plt.cm.coolwarm, vmin=0, vmax=1)
plt.colorbar()
plt.xlabel('$x_0$', fontsize=14)
plt.ylabel('$x_1$', fontsize=14)
# Both components come from the trained `Wf`; the previous version mixed in
# `W[1]` from the untrained demo weights, drawing the wrong vector.
plt.plot([0, Wf[0].item()],
         [0, Wf[1].item()],
         "-",
         c="forestgreen",
         linewidth=4)
plt.title("Predicted targets", fontsize=12)

# Right: ground-truth labels for comparison.
plt.subplot(1, 2, 2)
plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.6, edgecolor='k',
            cmap=plt.cm.coolwarm, vmin=0, vmax=1)
plt.colorbar()
plt.xlabel('$x_0$', fontsize=14)
plt.ylabel('$x_1$', fontsize=14)
plt.title("Actual targets", fontsize=12)
plt.tight_layout()

In [None]:
# Learned weights and bias after training.
(Wf, bf)

In [None]:
plt.figure(figsize=(6, 6))
plt.plot(loss_history)
plt.xlabel("epoch")
plt.ylabel("loss")
# A title lets the figure stand alone; the trailing ';' suppresses the Text repr.
plt.title("Training loss", fontsize=12);

# Classification metrics

In [None]:
from sklearn.metrics import classification_report

In [None]:
# Threshold the predicted probabilities at 0.5 to obtain hard 0/1 labels.
is_positive = y_pred >= 0.5
y_class = is_positive.astype(int)

In [None]:
# Preview only the first few labels; the full array has one entry per sample
# (1000 here), which is too long to display usefully.
y_class[:10]

In [None]:
# classification_report returns a formatted string, so print() it instead of
# displaying its repr.
print(classification_report(y, y_class))

In [None]:
# Distribution of predicted probabilities, split by the true class. Semi-
# transparent bars, a legend and axis labels keep the overlap readable.
plt.hist(y_pred[y == 0], range=(0, 1), alpha=0.6, label="actual class 0")
plt.hist(y_pred[y == 1], range=(0, 1), alpha=0.6, label="actual class 1")
plt.xlabel("predicted probability")
plt.ylabel("count")
plt.title("Predicted probabilities by actual class", fontsize=12)
plt.legend();

In [None]:
# Accuracy: fraction of samples whose predicted label matches the true one.
(y_class == y).mean()