In [159]:
import jax
# List the devices JAX can see; as the last expression of the cell, the
# returned list is what the notebook displays.
jax.devices()

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0),
 StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0)]

In [160]:
from sklearn.datasets import fetch_openml
import pickle
import os

In [161]:
def load_mnist(pickle_path='../data/mnist/data.pkl'):
    """Return the MNIST dataset, using a local pickle file as a download cache.

    Args:
        pickle_path: Location of the cache file. Defaults to the path this
            notebook has always used.

    Returns:
        The object unpickled from ``pickle_path`` when that file exists,
        otherwise the result of ``fetch_openml(name='mnist_784', ...)``
        (which is then written to ``pickle_path`` for next time).
    """
    pickle_path = os.fspath(pickle_path)
    if os.path.exists(pickle_path):
        # NOTE: pickle.load can execute arbitrary code. Only point this at a
        # cache file that this notebook wrote itself.
        with open(pickle_path, 'rb') as f:
            return pickle.load(f)

    dataset = fetch_openml(name='mnist_784', version=1, parser="auto")

    # The cache directory may not exist on a fresh checkout.
    cache_dir = os.path.dirname(pickle_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)

    # Write to a temporary name and rename, so an interrupted dump never
    # leaves a truncated file that the next run would try to unpickle.
    tmp_path = pickle_path + '.tmp'
    with open(tmp_path, 'wb') as f:
        pickle.dump(dataset, f)
    os.replace(tmp_path, pickle_path)
    return dataset

# Load the dataset once for the rest of the notebook (served from the pickle
# cache when that file already exists).
mnist = load_mnist()

In [162]:
from random import shuffle
from jax import numpy as jnp
import pandas as pd

# Pull the feature table and the label column out of the dataset bundle.
all_features, all_targets = mnist['data'], mnist['target']

# Membership flags: the first 60,000 positions start out True (train) and the
# rest False (test); shuffling then assigns rows to the two sets at random.
bool_vec = [position < 60_000 for position in range(len(all_targets))]
shuffle(bool_vec)


def split_df(df):
    """Return ``(rows flagged True, rows flagged False)`` of ``df`` per ``bool_vec``."""
    in_train = pd.Series(bool_vec).values
    in_test = pd.Series([not flag for flag in bool_vec]).values
    return df[in_train], df[in_test]


train_features, test_features = split_df(all_features)
train_targets, test_targets = split_df(all_targets)

def format_jnp(*dfs):
    """Convert each pandas object to a float32 JAX array.

    Returns a tuple holding one array per argument, in argument order.
    """
    converted = []
    for df in dfs:
        converted.append(jnp.asarray(df.to_numpy(), dtype='float32'))
    return tuple(converted)

# One-hot encode the labels, then move features and labels into float32 JAX
# arrays (features and targets are converted in two separate calls).
onehot_train_targets = pd.get_dummies(train_targets)
onehot_test_targets = pd.get_dummies(test_targets)

ftr_train, ftr_test = format_jnp(train_features, test_features)
tgt_train, tgt_test = format_jnp(onehot_train_targets, onehot_test_targets)

def normalize(ftr_df):
    """Scale every row of a 2-D array to unit Euclidean (L2) length.

    Args:
        ftr_df: 2-D array with one example per row.

    Returns:
        An array of the same shape whose rows have L2 norm 1. A row that is
        entirely zero has no direction and is returned unchanged (all zeros)
        instead of being turned into NaNs by a division by zero.
    """
    row_norms = jnp.linalg.norm(ftr_df, axis=1, keepdims=True)
    # Substitute 1 for zero norms so the division below is a no-op there.
    safe_norms = jnp.where(row_norms == 0, 1.0, row_norms)
    return ftr_df / safe_norms
norm_ftr_train, norm_ftr_test = normalize(ftr_train), normalize(ftr_test)

# Quick sanity check to make sure that features have been properly normalized:
# every training row should now have an L2 norm of (approximately) 1.
# jnp.all reduces the whole comparison in one array operation; the builtin
# all() would instead walk the array one element at a time from Python.
is_normalized = bool(jnp.all(jnp.isclose(
    jnp.linalg.norm(norm_ftr_train, axis=1),
    1.0
)))

assert is_normalized, 'normalized training rows do not all have unit L2 norm'

In [91]:
from random import randint
from jax import random as jrand

class DistSampler:
    """Stateful convenience wrapper around JAX's explicit PRNG keys.

    The instance stores a key and splits it on every draw, so consecutive
    calls return different samples without the caller handling keys.
    """

    def __init__(self, seed=None):
        """Create the sampler.

        Args:
            seed: Integer used to build the initial PRNG key. When None (the
                default) a seed is drawn with ``randint(0, 10**6)``, which is
                the previous behaviour. Pass a fixed value for reproducible
                draws. The seed actually used is kept on ``self.seed``.
        """
        if seed is None:
            seed = randint(0, 10**6)
        self.seed = seed
        self.key = jrand.PRNGKey(seed)

    def _next_key(self):
        """Advance the stored key and return a fresh subkey for one draw."""
        self.key, subkey = jrand.split(self.key)
        return subkey

    def normal(self, *shape):
        """Return ``jax.random.normal`` samples with shape ``shape``."""
        return jrand.normal(self._next_key(), shape=tuple(shape))

    def uniform(self, *shape):
        """Return ``jax.random.uniform`` samples with shape ``shape``."""
        return jrand.uniform(self._next_key(), shape=tuple(shape))


dist_sampler = DistSampler()

In [92]:
# Pick one random training example and render it as ASCII art.
# random.randint includes BOTH endpoints, so the upper bound has to be the
# last valid row index, not the number of rows.
idx = randint(0, ftr_train.shape[0] - 1)
x, y = ftr_train[idx], tgt_train[idx]

# ftr_train holds the un-normalized features; view the row as a 28x28 grid and
# draw an '@' wherever the value exceeds 100.
img = jnp.reshape(x, (28,28))
for row in img:
    print(''.join([
        '@' if pix > 100 else ' ' for pix in row
    ]))

print(f'\n\nAnswer: {y}')

                            
                            
                            
                            
                            
         @@@@@@             
        @@@@@@@@            
        @@@@@@@@@           
          @@@@@@@@          
              @@@@          
              @@@@          
              @@@@          
              @@@           
             @@@@           
           @@@@@@@          
          @@@@@@@@@@        
          @@@@@@@@@@@       
          @@@@@@@@@@@@      
           @@@@@@@@@        
            @@@             
           @@@@             
          @@@@@             
           @@@@             
           @@@@             
           @@@              
                            
                            
                            


Answer: [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]


In [152]:
from jax import grad, nn, jit

# Number of training examples (rows) and of features per example (columns).
num_obs, num_ftrs = ftr_train.shape

def xlayer(x, w, b):
    """Affine part of one dense layer: ``dot(w, x) + b`` (no activation).

    As used in this notebook, x has dims (#In,), w has dims (#Out, #In) and
    b has dims (#Out,).
    """
    weighted_input = jnp.dot(w, x)
    return weighted_input + b

def compose_layers(activations, index=0):
    """Build ``f(x, params)`` that applies layers ``index`` through the last one.

    Layer ``i`` computes ``activations[i](xlayer(h, params[2*i], params[2*i+1]))``
    where ``h`` is the previous layer's output (or ``x`` for the first layer
    applied), i.e. ``params`` is the flat list ``[w0, b0, w1, b1, ...]``.
    """
    # Touch the starting layer up front so that a missing layer raises
    # IndexError while composing, not later when the network is called.
    _ = activations[index]
    stages = [(i, activations[i]) for i in range(index, len(activations))]

    def forward(x, params):
        hidden = x
        for i, afn in stages:
            hidden = afn(xlayer(hidden, params[2*i], params[2*i+1]))
        return hidden

    return forward

def make_nn_and_params(*layers, in_dim=None, sampler=None):
    """Create a network function plus a factory for its random parameters.

    Args:
        *layers: ``(activation, out_dim)`` pairs, one per layer, in order.
        in_dim: Size of the input vector fed to the first layer.
        sampler: Callable invoked as ``sampler(*shape)`` to produce each
            initial weight matrix and bias vector.

    Returns:
        ``(nn_func, make_params)``. ``nn_func(x, params)`` runs the composed
        layers and applies softmax to the result. ``make_params()`` returns
        the flat list ``[W0, B0, W1, B1, ...]`` with ``Wi`` of shape
        ``(out_dim_i, in_dim_i)`` and ``Bi`` of shape ``(out_dim_i,)``.
    """
    activations, dims = list(zip(*layers))
    stacked_layers = compose_layers(activations)

    def nn_func(x, params):
        return nn.softmax(stacked_layers(x, params))

    def make_params():
        # Work out every shape first, then sample them in W, B, W, B... order.
        shapes = []
        fan_in = in_dim
        for fan_out in dims:
            shapes.append((fan_out, fan_in))  # Wi
            shapes.append((fan_out,))  # Bi
            fan_in = fan_out
        return [sampler(*shape) for shape in shapes]

    return nn_func, make_params

# Network: three dense tanh layers with 30, 50 and 10 outputs. The input size
# is taken from the training data (num_ftrs) rather than repeated as a literal,
# so the network cannot silently disagree with the feature matrix.
neural_net, make_params = make_nn_and_params(
    (jnp.tanh, 30),
    (jnp.tanh, 50),
    (jnp.tanh, 10),
    #(nn.sigmoid, 30),
    #(nn.sigmoid, 10),
    in_dim = num_ftrs,
    sampler = dist_sampler.normal
)

# def neural_net(x, params):
#     w1, b1, w2, b2 = params #, w3, b3 = params
#     t0 = tanh_layer(x, w1, b1)
#     t1 = tanh_layer(t0, w2, b2)
#     #t2 = leaky_relu_layer(t1, w3, b3)
#     return nn.softmax(t1)

# params = [
#     dist_sampler.normal(3000, num_ftrs), # w1
#     dist_sampler.normal(3000), # b1
# #     dist_sampler.uniform(50, 784), # w2
# #     dist_sampler.uniform(50), # b2
#     dist_sampler.normal(10, 3000), # w3
#     dist_sampler.normal(10), # b3
# ]

def cross_entropy(prediction, truth, eps=1e-7):
    """Element-wise binary cross-entropy between ``prediction`` and ``truth``.

    Computes ``-t*log(p) - (1-t)*log(1-p)`` for every element.

    Args:
        prediction: Predicted probabilities.
        truth: Target values, same shape as ``prediction``.
        eps: Predictions are clipped to ``[eps, 1 - eps]`` before taking logs.
            Without this, a prediction of exactly 0 or 1 produces
            ``0 * log(0)``, which evaluates to NaN. Predictions already inside
            that interval are not changed.

    Returns:
        Array of per-element losses with the shape of the inputs.
    """
    clipped = jnp.clip(prediction, eps, 1. - eps)
    return -truth * jnp.log(clipped) - (1. - truth) * jnp.log(1. - clipped)

# Cross entropy
def loss_fn(params, x, y):
    """Scalar loss for one example.

    Runs ``neural_net`` on ``x`` with ``params`` and sums the element-wise
    cross-entropy between the network output and the target ``y``.
    """
    prediction = neural_net(x, params)
    return jnp.sum(cross_entropy(prediction, y))

# GOTCHA
# By default, grad differentiates only with respect to the first parameter of
# the function it is given. argnums=0 spells that default out: the gradient
# is taken with respect to `params`, not `x` or `y`.
loss_gradient = jit(grad(loss_fn, argnums=0))
jit_nn = jit(neural_net)
jit_loss = jit(loss_fn)

In [153]:
from functools import partial
from jax import vmap

# Holds the single compiled per-example gradient function, keyed by the
# loss_fn object it was built from (so redefining loss_fn rebuilds it).
_vector_loss_grad_cache = {}


def make_loss_grad(params):
    """Return ``f(batch_x, batch_y)`` giving per-example loss gradients at ``params``.

    The returned callable maps ``grad(loss_fn)`` over the leading axis of
    ``batch_x`` and ``batch_y`` and returns, for every entry of ``params``,
    an array of gradients with one slice per example along axis 0.

    Why the cache: jitting a freshly built function on every call (with
    ``params`` closed over) makes JAX retrace and recompile on every training
    step. Here the vmapped gradient is jitted once and ``params`` is passed
    as an ordinary, unmapped argument, so later calls reuse the compiled code.
    """
    vector_loss_grad = _vector_loss_grad_cache.get(loss_fn)
    if vector_loss_grad is None:
        # params is shared by the whole batch (None); x and y are mapped
        # over their first axis.
        vector_loss_grad = jit(
            vmap(grad(loss_fn), in_axes=(None, 0, 0), out_axes=0)
        )
        _vector_loss_grad_cache.clear()
        _vector_loss_grad_cache[loss_fn] = vector_loss_grad
    return partial(vector_loss_grad, params)

In [154]:
from random import sample

# Training inputs (row-normalized features) and their one-hot targets, read as
# globals by the SGD helpers defined in this cell.
Obs, Resp = norm_ftr_train, tgt_train

def random_rows(n, *arrs):
    """Pick the same ``n`` distinct random rows from every array in ``arrs``.

    Args:
        n: Number of rows to draw (without replacement).
        *arrs: Arrays indexed along axis 0. The number of rows of the first
            one defines the range the indices are drawn from.

    Returns:
        A list with one array per input, each holding the selected rows in
        the same order.

    Raises:
        ValueError: from ``random.sample`` when ``n`` is larger than the
            number of rows (or negative).
    """
    # sample() accepts a range directly, so no row-index list is materialized;
    # the index array is built once and shared by every jnp.take below.
    idxs = jnp.asarray(sample(range(arrs[0].shape[0]), n))
    return [jnp.take(x, idxs, axis=0) for x in arrs]

def sgd_minibatch(params, batch_size, learning_rate=0.01):
    """One minibatch SGD step, computed one example at a time.

    Draws ``batch_size`` random training rows (with replacement) from
    ``Obs``/``Resp``. Every gradient is evaluated at the *incoming* ``params``
    and ``(learning_rate / batch_size) * gradient`` is subtracted from a
    running copy, which is returned as the new parameter list.
    """
    updated = [jnp.copy(p) for p in params]
    for _ in range(batch_size):
        row = randint(0, num_obs - 1)
        example_grads = loss_gradient(params, Obs[row], Resp[row])

        stepped = []
        for current, example_grad in zip(updated, example_grads):
            stepped.append(current - ((learning_rate / batch_size) * example_grad))
        updated = stepped
    return updated

def sgd_vector_minibatch(params, batch_size, learning_rate=0.01):
    """One minibatch SGD step using the vectorized per-example gradient.

    Samples ``batch_size`` distinct rows from ``Obs``/``Resp``, computes all
    per-example gradients in one call, averages them over the batch axis and
    returns ``param - learning_rate * mean_gradient`` for every parameter.
    """
    batch_grad_fn = make_loss_grad(params)
    batch_x, batch_y = random_rows(batch_size, Obs, Resp)
    per_example_grads = batch_grad_fn(batch_x, batch_y)

    stepped = []
    for param, stacked_grads in zip(params, per_example_grads):
        mean_grad = jnp.average(stacked_grads, 0)
        stepped.append(param - (learning_rate * mean_grad))
    return stepped

def avg_training_loss(params):
    """Estimate the mean training loss from 50 randomly drawn training examples."""
    n = 50
    sampled_rows = (randint(0, num_obs - 1) for _ in range(n))
    total = sum(
        jit_loss(params, norm_ftr_train[row], tgt_train[row])
        for row in sampled_rows
    )
    return total / n

def _max_idx(vec):
    max_x = vec[0]
    max_i = 0
    for i, x in enumerate(vec):
        if vec[i] > max_x:
            max_x = vec[i]
            max_i = i
    return max_i
        
def show_prediction_accuracy(params):
    """Print ``correct/1000`` for the first 1000 rows of the test set.

    A row counts as correct when the largest entry of the network output is
    at the same position as the largest entry of its target vector.
    """
    n = 1000
    correct = sum(
        1
        for idx in range(n)
        if _max_idx(jit_nn(norm_ftr_test[idx], params)) == _max_idx(tgt_test[idx])
    )
    print(f'{correct}/{n}')

In [155]:
# Draw a fresh random set of weights and biases for the network.
params = make_params()

In [156]:
# import time

# class Clock():
#     def __init__(self):
#         self.start_time = time.time()
        
#     def start(self):
#         self.start_time = time.time()
    
#     def stop(self):
#         return time.time() - self.start_time
    
# clock = Clock()

# clock.start()
# for i, epoch_num in enumerate(range(100)):
#     params = sgd_minibatch(params, 50, .5)
#     if (i+1) % 20 == 0:
#         print(f'({i+1}) average loss: {avg_training_loss(params)}')
# print(f'minibatch time = {clock.stop()}')

# clock.start()
# for i, epoch_num in enumerate(range(100)):
#     params = sgd_vector_minibatch(params, 50, .5)
#     if (i+1) % 20 == 0:
#         print(f'({i+1}) average loss: {avg_training_loss(params)}')
# print(f'vector_minibatch time = {clock.stop()}')

In [157]:
# 10,000 vectorized minibatch steps (batch size 60, learning rate 0.5).
# A sampled training loss is printed every 100 steps and the test accuracy
# every 1000 steps.
for i in range(10_000):
    params = sgd_vector_minibatch(params, 60, .5)
    #params = sgd_minibatch(params, 30, .5)
    steps_done = i + 1
    if steps_done % 100 == 0:
        print(f'({i+1}) average loss: {avg_training_loss(params)}')
    if steps_done % 1000 == 0:
        show_prediction_accuracy(params)

(100) average loss: 2.8394389152526855
(200) average loss: 2.632817029953003
(300) average loss: 2.053027868270874
(400) average loss: 1.8389965295791626
(500) average loss: 2.0067193508148193
(600) average loss: 1.8944953680038452
(700) average loss: 1.8663491010665894
(800) average loss: 1.860311508178711
(900) average loss: 1.717785358428955
(1000) average loss: 1.8699570894241333
861/1000
(1100) average loss: 1.7482151985168457
(1200) average loss: 1.866131067276001
(1300) average loss: 1.701192855834961
(1400) average loss: 1.625333547592163
(1500) average loss: 1.633105754852295
(1600) average loss: 1.9269949197769165
(1700) average loss: 1.746794581413269
(1800) average loss: 1.5866358280181885
(1900) average loss: 1.6834129095077515
(2000) average loss: 1.8284664154052734
896/1000
(2100) average loss: 1.7812992334365845
(2200) average loss: 1.6934010982513428
(2300) average loss: 1.584769368171692
(2400) average loss: 1.6393221616744995
(2500) average loss: 1.5406081676483154
(

In [158]:
# Continue training from the current params for another 10,000 vectorized
# minibatch steps (batch size 60) with a much smaller learning rate (0.005).
# A sampled training loss is printed every 100 steps and the test accuracy
# every 1000 steps.
for i in range(10_000):
    params = sgd_vector_minibatch(params, 60, .005)
    #params = sgd_minibatch(params, 30, .5)
    steps_done = i + 1
    if steps_done % 100 == 0:
        print(f'({i+1}) average loss: {avg_training_loss(params)}')
    if steps_done % 1000 == 0:
        show_prediction_accuracy(params)

(100) average loss: 1.4572159051895142
(200) average loss: 1.490925669670105
(300) average loss: 1.466302514076233
(400) average loss: 1.4840394258499146
(500) average loss: 1.4379446506500244
(600) average loss: 1.5711660385131836
(700) average loss: 1.4854987859725952
(800) average loss: 1.3935836553573608
(900) average loss: 1.603747010231018
(1000) average loss: 1.467980980873108
943/1000
(1100) average loss: 1.6262611150741577
(1200) average loss: 1.5837026834487915
(1300) average loss: 1.4796726703643799
(1400) average loss: 1.4464231729507446
(1500) average loss: 1.582987904548645
(1600) average loss: 1.5626029968261719
(1700) average loss: 1.5450429916381836
(1800) average loss: 1.393280267715454
(1900) average loss: 1.6723823547363281
(2000) average loss: 1.405208706855774
945/1000
(2100) average loss: 1.4778988361358643
(2200) average loss: 1.5622962713241577
(2300) average loss: 1.435713291168213
(2400) average loss: 1.560678243637085
(2500) average loss: 1.4793624877929688
