<a href="https://colab.research.google.com/github/bkestelman/jax-ml-tutorial/blob/master/logistic_regression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The main function we want from JAX is jax.grad(), which will automatically calculate gradients for us.
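A minimal illustration of jax.grad: given a Python function with a scalar output, it returns a new function that computes the derivative. The `argnums` parameter selects which argument to differentiate with respect to. The functions `f` and `g` here are made-up examples.

```python
import jax

def f(x):
    return x ** 2 + 3.0 * x   # f'(x) = 2x + 3

df = jax.grad(f)              # a new function computing f'(x)
print(df(2.0))                # 7.0

def g(a, b):
    return a * b ** 2         # dg/db = 2ab

dg_db = jax.grad(g, argnums=1)  # differentiate w.r.t. the second argument
print(dg_db(3.0, 2.0))          # 12.0
```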

We also import jax.numpy, which provides a NumPy-compatible API whose functions JAX can differentiate, so we can keep writing good ol' familiar NumPy code.
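A short sketch of what this buys us, using the conventional alias `jnp` for `jax.numpy`: a function written with `jnp.sum` and `jnp.sin` can be passed straight to `jax.grad`. One difference from NumPy worth knowing is that JAX arrays are immutable, so element updates go through `.at[...]`, which returns a new array.

```python
import jax
import jax.numpy as jnp

def total_sq(v):
    return jnp.sum(jnp.sin(v) ** 2)

v = jnp.array([0.0, 1.0, 2.0])
g = jax.grad(total_sq)(v)
# d/dv sin(v)^2 = 2 sin(v) cos(v) = sin(2v), elementwise
print(g, jnp.sin(2 * v))

# v[0] = 5.0 would raise an error; .at[].set() returns an updated copy instead
v2 = v.at[0].set(5.0)
print(v, v2)
```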

We import regular numpy as onp (original numpy). We don't really need this, but it is convenient in a few special cases (like random numbers). We discuss this more later.
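The random-number case is where the two libraries differ most: NumPy keeps hidden global state, while `jax.random` takes an explicit PRNG key, and the same key always produces the same numbers. A small sketch of the difference (the seed 0 is arbitrary):

```python
import jax
import numpy as onp

# NumPy: hidden global state, so each call returns new numbers
a1 = onp.random.rand(1, 2)
a2 = onp.random.rand(1, 2)

# JAX: randomness is driven by an explicit key
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
W1 = jax.random.uniform(subkey, (1, 2))
W2 = jax.random.uniform(subkey, (1, 2))  # same key -> same numbers as W1
key, subkey = jax.random.split(key)
W3 = jax.random.uniform(subkey, (1, 2))  # fresh subkey -> new numbers
print(a1, a2)
print(W1, W2, W3)
```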

In [0]:
import jax
import jax.numpy as np  # JAX's NumPy-compatible (differentiable) API
from jax import grad
import numpy as onp     # original numpy

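Simple logistic regression (observe how we use jax.grad for back-propagation!). Strictly speaking, the model below is not textbook logistic regression: forward applies no sigmoid and loss is the squared error, so it is a least-squares linear model whose raw output we threshold at 0.5 to classify.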

In [0]:
def forward(x, W, b):
  # Linear model: W has shape (output_size, input_size), x has shape (input_size,)
  return W.dot(x) + b

In [0]:
def loss(x, target, W, b):
  # Squared error between the model's raw output and the 0/1 target
  return np.sum((forward(x, W, b) - target) ** 2)
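To make concrete what `grad(loss, argnums=...)` computes, here is a small check with made-up inputs that JAX's gradients match the hand-derived ones. With residual r = W·x + b - target and loss L = sum(r²), the derivatives are ∂L/∂W = 2 r xᵀ (same shape as W) and ∂L/∂b = 2 sum(r). The block repeats the two definitions so it runs on its own.

```python
import jax
import jax.numpy as jnp

def forward(x, W, b):
    return W.dot(x) + b

def loss(x, target, W, b):
    return jnp.sum((forward(x, W, b) - target) ** 2)

x = jnp.array([1.0, 1.0])
target = jnp.array([1.0])
W = jnp.array([[0.3, 0.7]])
b = 0.1

r = forward(x, W, b) - target        # residual, shape (1,)
dW_manual = 2.0 * jnp.outer(r, x)    # shape (1, 2), same as W
db_manual = 2.0 * jnp.sum(r)         # scalar, same as b

# argnums=(2, 3) returns the gradients w.r.t. W and b as a tuple
dW_auto, db_auto = jax.grad(loss, argnums=(2, 3))(x, target, W, b)
print(dW_manual, dW_auto)
print(db_manual, db_auto)
```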

In [0]:
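def backprop(x, target, W, b, learning_rate):
  # One gradient-descent step: argnums=(2, 3) gives d(loss)/dW and d(loss)/db,
  # both evaluated at the current parameters, in a single call
  dW, db = grad(loss, argnums=(2, 3))(x, target, W, b)
  W = W - learning_rate * dW  # rebind rather than modify the caller's array in place
  b = b - learning_rate * db
  return W, b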
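For comparison, here is a sketch of textbook logistic regression, which the cells above approximate with a linear model and squared error: the score is squashed by a sigmoid and trained with binary cross-entropy. The names `forward_logistic`, `loss_logistic` and `backprop_logistic` are hypothetical. The loss uses the identity BCE(z, t) = softplus(z) - t·z on the raw score z, which avoids taking the log of a sigmoid directly; its gradient with respect to z is sigmoid(z) - t. The demo trains on the four-row AND truth table rather than the random data.

```python
import jax
import jax.numpy as jnp

def forward_logistic(x, W, b):
    # Probability that the label is 1
    return jax.nn.sigmoid(W.dot(x) + b)

def loss_logistic(x, target, W, b):
    # Binary cross-entropy written in terms of the raw score z
    z = W.dot(x) + b
    return jnp.sum(jax.nn.softplus(z) - target * z)

@jax.jit
def backprop_logistic(x, target, W, b, learning_rate):
    dW, db = jax.grad(loss_logistic, argnums=(2, 3))(x, target, W, b)
    return W - learning_rate * dW, b - learning_rate * db

X = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
Y = jnp.array([[0.], [0.], [0.], [1.]])
W, b = jnp.zeros((1, 2)), jnp.zeros(())
for epoch in range(500):
    for x, t in zip(X, Y):
        W, b = backprop_logistic(x, t, W, b, learning_rate=0.5)

preds = [int(forward_logistic(x, W, b)[0] >= 0.5) for x in X]
print(W, b, preds)
```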

Initialize parameters (weight matrix and bias)

In [0]:
def init_params(input_size, output_size):
  W = onp.random.rand(output_size, input_size)  # uniform in [0, 1)
  b = 0.0
  return W, b

In [7]:
W, b = init_params(2, 1)
print(W, b)

[[0.94147518 0.32437763]] 0.0


Define how to train and test the model

In [0]:
def train(X, labels, W, b, learning_rate):
  # One pass over the data, updating the parameters after every example (SGD)
  for x, label in zip(X, labels):
    W, b = backprop(x, label, W, b, learning_rate)
  return W, b
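The `train` function takes one SGD step per example in a Python loop. A common alternative in JAX, sketched here as an assumption rather than what this notebook does, is full-batch gradient descent: `jax.vmap` evaluates `forward` on every row at once, and `jax.jit` compiles the whole update. The names `batch_loss` and `train_step` are hypothetical, the loss is the mean squared error over the batch, and the data is the four-row AND truth table instead of the random samples.

```python
import jax
import jax.numpy as jnp

def forward(x, W, b):
    return W.dot(x) + b

def batch_loss(W, b, X, Y):
    # vmap maps forward over the rows of X; W and b are shared (in_axes=None)
    preds = jax.vmap(forward, in_axes=(0, None, None))(X, W, b)  # shape (N, 1)
    return jnp.mean((preds - Y) ** 2)

@jax.jit
def train_step(W, b, X, Y, learning_rate):
    dW, db = jax.grad(batch_loss, argnums=(0, 1))(W, b, X, Y)
    return W - learning_rate * dW, b - learning_rate * db

X = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
Y = jnp.array([[0.], [0.], [0.], [1.]])
W, b = jnp.zeros((1, 2)), jnp.zeros(())
for _ in range(1000):
    W, b = train_step(W, b, X, Y, 0.1)
print(W, b)
```

After the first call, `train_step` runs as compiled code, so the Python loop only dispatches one fused update per iteration.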

In [0]:
def test(X, labels, W, b):
  correct = 0
  for x, label in zip(X, labels):
    raw_pred = forward(x, W, b)
    pred = 0 if raw_pred < 0.5 else 1  # threshold the raw output at 0.5
    if pred == label:
      correct += 1
  print('Accuracy:', correct / len(X))
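The per-example loop in `test` can also be written as array operations. A sketch of a hypothetical `accuracy` helper for the single-output case used here, returning the value instead of printing it; it applies the same rule as `test` (predict 1 when the raw output is at least 0.5). The parameters in the usage lines are chosen by hand for illustration.

```python
import jax.numpy as jnp

def accuracy(X, labels, W, b, threshold=0.5):
    # X: (N, input_size); labels: (N, 1); W: (1, input_size)
    raw_preds = X @ W.T + b                              # forward() for every row at once
    preds = (raw_preds >= threshold).astype(labels.dtype)
    return jnp.mean(preds == labels)                     # fraction of matches

X_demo = jnp.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
Y_demo = jnp.array([[0.], [0.], [0.], [1.]])
W_demo = jnp.array([[0.5, 0.5]])   # hypothetical hand-picked parameters
b_demo = -0.25
print('Accuracy:', accuracy(X_demo, Y_demo, W_demo, b_demo))
```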

We will test whether our model can learn boolean AND.

In [10]:
N = 200
X = [ onp.random.randint(2, size=2) for _ in range(N) ]  # random pairs of 0/1 inputs
labels = [ np.array([float(x[0] and x[1])]) for x in X ]  # AND of the two inputs

train_test_split = int(0.7 * N)
train_X, train_labels = X[:train_test_split], labels[:train_test_split]
test_X, test_labels = X[train_test_split:], labels[train_test_split:]
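Because each input is 0 or 1, the 200 sampled points contain at most four distinct rows, so the test set repeats inputs already seen in training. A sketch that builds the full AND truth table directly and confirms how few distinct rows the random sample has (`X_table`, `labels_table` and `X_sampled` are illustrative names):

```python
import numpy as onp
import jax.numpy as jnp

# Full truth table for two boolean inputs
X_table = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=jnp.float32)
# AND is 1 only when both inputs are 1; shape (4, 1) to match the model's output
labels_table = jnp.logical_and(X_table[:, 0], X_table[:, 1]).astype(jnp.float32)[:, None]
print(labels_table.ravel())

# The random sample can only contain rows from the table above
X_sampled = onp.random.randint(2, size=(200, 2))
print(onp.unique(X_sampled, axis=0))
```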



First, we test the model without any training. This result depends on the random initial parameters (and on the randomly sampled data), so it will vary from run to run.

In [11]:
test(test_X, test_labels, W, b)

Accuracy: 0.6666666666666666


Train the model for one pass over the training set, then show the weights and bias after training

In [12]:
W, b = train(train_X, train_labels, W, b, learning_rate=0.01)
print(W, b)

[[0.6205038 0.3154657]] -0.24448743


Test again after training

In [13]:
test(test_X, test_labels, W, b)

Accuracy: 1.0
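As a sanity check on the trained parameters, assuming the four possible inputs appear about equally often (roughly true for 200 random samples), the squared-error optimum for this linear model can be computed in closed form with least squares. The trained bias above (about -0.244) is close to the optimum's -0.25, while the weights are further off; a single pass of SGD at `learning_rate=0.01` need not reach the optimum, but thresholding at 0.5 only needs the parameters to be roughly right.

```python
import jax.numpy as jnp

# Each row is [x1, x2, 1]; the constant column plays the role of the bias b
A = jnp.array([[0., 0., 1.],
               [0., 1., 1.],
               [1., 0., 1.],
               [1., 1., 1.]])
y = jnp.array([0., 0., 0., 1.])   # AND of x1 and x2

coef, *_ = jnp.linalg.lstsq(A, y)
print(coef)              # approximately [0.5, 0.5, -0.25]

raw = A @ coef           # approximately [-0.25, 0.25, 0.25, 0.75]
preds = (raw >= 0.5).astype(jnp.int32)
print(preds)             # [0 0 0 1]
```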
