### Import Required Packages and Set Options

In [10]:
import os
import sys

import numpy as np
import pandas as pd
import multiprocessing as mp

import matplotlib.pyplot as plt
import seaborn as sns

from datetime import datetime
from itertools import combinations
from functools import partial

import jax.numpy as jnp
from jax import grad, jit, vmap

from sklearn.datasets import load_boston

np.set_printoptions(precision=4, suppress=True)
jnp.set_printoptions(precision=4, suppress=True)

In [11]:
REPO_ROOT = "/Users/ericlundquist/Repos/rankfm"

### Load Example Data

In [12]:
boston = load_boston()
features = boston.data
target = boston.target

# convert to z-scores for numerical stability and add constant column
features = (features - features.mean(axis=0)) / features.std(axis=0)
features = jnp.hstack([jnp.ones([features.shape[0], 1]), features])

# transform the target to a binary value
target = jnp.where(target > target.mean(), 1, 0)

coeffs = ['CONS'] + list(boston.feature_names)
print(" ".join(coeffs))
print(features.shape)
print(target.shape)

CONS CRIM ZN INDUS CHAS NOX RM AGE DIS RAD TAX PTRATIO B LSTAT
(506, 14)
(506,)




In [13]:
print(features.mean(axis=0))
print(features.std(axis=0))
print(target.mean())

[ 1.  0.  0. -0. -0. -0.  0.  0.  0. -0. -0.  0.  0. -0.]
[0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
0.41304347


### Train a Logistic Regression Model with Mini-Batch SGD Gradient Updates 

In [14]:
class LogReg:
    """train a logistic regression model using SGD"""
    
    def __init__(self, initial_weights, learning_rate=0.1):
        self.learning_rate = learning_rate
        self.weights = initial_weights
    
    @partial(jit, static_argnums=(0,))
    def sigmoid(self, x):
        return 1 / (1 + jnp.exp(-x))

    @partial(jit, static_argnums=(0,))
    def predict_proba(self, weights, inputs):
        return self.sigmoid(jnp.dot(inputs, weights))

    @partial(jit, static_argnums=(0,))
    def loss_value(self, weights, inputs, target):
        predictions = self.predict_proba(weights, inputs)
        loss = -jnp.mean((target * jnp.log(predictions)) + ((1 - target) * jnp.log(1 - predictions)))
        return loss
    
    # auto-diff the loss function wrt all weights
    loss_gradient = grad(loss_value, argnums=[1])
    
    def create_batches(self, inputs, target, batch_size):
        n_batches = np.round(inputs.shape[0] / batch_size)

        batch_inputs = np.array_split(inputs, n_batches)
        batch_target = np.array_split(target, n_batches)
        return list(zip(batch_inputs, batch_target))
    
    def update_weights(self, batches, weights, learning_rate):
        for (b_input, b_target) in batches:
            gradients = self.loss_gradient(weights, b_input, b_target)
            weights -= learning_rate * gradients
        return weights
    
    
    def fit(self, inputs, target, batch_size, epochs):
        
        print("initial loss: {}".format(round(self.loss_value(self.weights, inputs, target), 2)))
        for i in range(epochs):
            batches = self.create_batches(inputs, target, batch_size)
            self.weights = self.update_weights(batches, self.weights, self.learning_rate)
            print("epoch: {} loss: {}".format(i, self.loss_value(self.weights, inputs, target)))
    

#### Full Batches

In [15]:
initial_weights = jnp.zeros(features.shape[1])
logreg = LogReg(initial_weights)

logreg.weights
# logreg.fit(features, target, batch_size=len(target), epochs=10)

DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

In [17]:
final_weights = mini_batch_train(features, target, initial_weights, batch_size=len(target), lr=0.1, epochs=10)

NameError: name 'mini_batch_train' is not defined

#### Mini-Batches

In [683]:
initial_weights = jnp.zeros(inputs.shape[1])
final_weights = mini_batch_train(features, target, initial_weights, batch_size=10, lr=0.1, epochs=10)

initial loss: 0.6899999976158142
epoch: 0 loss: 0.39493054151535034
epoch: 1 loss: 0.35050806403160095
epoch: 2 loss: 0.3332710564136505
epoch: 3 loss: 0.323830783367157
epoch: 4 loss: 0.3177531063556671
epoch: 5 loss: 0.3134887218475342
epoch: 6 loss: 0.3103281557559967
epoch: 7 loss: 0.3078918159008026
epoch: 8 loss: 0.30595651268959045
epoch: 9 loss: 0.30438217520713806


#### Single-Sample SGD

In [671]:
initial_weights = jnp.zeros(inputs.shape[1])
final_weights = mini_batch_train(features, target, initial_weights, batch_size=1, lr=0.1, epochs=10)

initial loss: 0.6899999976158142
epoch: 0 loss: 0.35765910148620605
epoch: 1 loss: 0.3244864344596863
epoch: 2 loss: 0.3173229694366455
epoch: 3 loss: 0.3147937059402466
epoch: 4 loss: 0.31364718079566956
epoch: 5 loss: 0.31306958198547363
epoch: 6 loss: 0.31278374791145325
epoch: 7 loss: 0.3126581013202667
epoch: 8 loss: 0.31262409687042236
epoch: 9 loss: 0.31264904141426086


### Mess Around with VMAP()

In [28]:
users = [0, 1, 2]
items = [0, 1, 2]
samples = [(0, 0, 1), (1, 2, 1), (2, 0, 2)]

In [29]:
user_weights = np.zeros(3)
item_weights = np.zeros(3)

In [30]:
user_weights, item_weights

(array([0., 0., 0.]), array([0., 0., 0.]))

In [33]:
def update_weights(samples, user_weights, item_weights):
    for u, i, j in samples:
        user_weights[u] += u
        item_weights[i] += i
        item_weights[j] += j
    return user_weights, item_weights
        

In [34]:
user_weights, item_weights = update_weights(samples, user_weights, item_weights)

In [632]:
weights = np.array([1, 2, 3, 4])
matrix = np.arange(16).reshape((4, 4))
batches = np.split(matrix, 4)

In [633]:
matrix

array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15]])

In [634]:
batches

[array([[0, 1, 2, 3]]),
 array([[4, 5, 6, 7]]),
 array([[ 8,  9, 10, 11]]),
 array([[12, 13, 14, 15]])]

In [675]:
def running_sum(batches, weights, cursum):
    for batch in batches:
        cursum += jnp.dot(batch, weights)
    return cursum


In [676]:
%%timeit

cursum = 1000
running_sum(batches, weights, cursum)

3.33 ms ± 119 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [677]:
np.sum(np.dot(matrix, weights))

320

In [678]:
vmap_running_sum = vmap(running_sum, in_axes=(0, None, None))

In [679]:
%%timeit

cursum = 1000
vmap_running_sum(batches, weights, cursum)

7.37 ms ± 370 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
