<a href="https://colab.research.google.com/github/bkestelman/jax-ml-tutorial/blob/master/logistic_regression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import jax
import jax.numpy as np
from jax import grad
import numpy as onp # original numpy
from random import randint

In [0]:
def forward(x, W, b):
  return W.dot(x) + b

In [0]:
def loss(x, target, W, b):
  return np.sum( (forward(x, W, b) - target)**2 )

In [0]:
def backprop(x, target, W, b, learning_rate):
  W -= grad(loss, argnums=2)(x, target, W, b) * learning_rate
  b -= grad(loss, argnums=3)(x, target, W, b) * learning_rate
  return W, b

In [0]:
def init_params(input_size, output_size):
  W = onp.random.rand(output_size, input_size)
  b = 0.0
  return W, b

In [92]:
W, b = init_params(2, 1)
print(W, b)

[[0.5353666  0.96355464]] 0.0


In [0]:
def train(X, labels, W, b, learning_rate):
  for x, label in zip(X, labels):
    W, b = backprop(x, label, W, b, learning_rate)
  return W, b

In [0]:
def test(X, labels, W, b):
  correct = 0
  for x, label in zip(X, labels):
    raw_pred = forward(x, W, b)
    pred = 0 if raw_pred < 0.5 else 1
    if pred == label:
      correct += 1
  print('Accuracy:', correct / len(X))

In [0]:
N = 200
X = [ onp.random.randint(2, size=2) for _ in range(N) ]
labels = [ np.array([float(X[0] and X[1])]) for X in X ]

train_test_split = int(0.7 * N)
train_X, train_labels = X[:train_test_split], labels[:train_test_split]
test_X, test_labels = X[train_test_split:], labels[train_test_split:]

In [94]:
test(test_X, test_labels, W, b)

Accuracy: 0.4666666666666667


In [95]:
W, b = train(train_X, train_labels, W, b, learning_rate=0.01)
print(W, b)

[[0.4043496  0.68912023]] -0.26370358


In [96]:
test(test_X, test_labels, W, b)

Accuracy: 1.0
