# Imports 

In [None]:
%load_ext autoreload
%autoreload 2
import jax.numpy as jnp
from jax import grad
import jax
import numpy as np
import math
%matplotlib inline
from grad import *
from model import *

# Testing my autodiff library against JAX

In [30]:
# Compare value and gradient of sigmoid(2) between the home-made autodiff
# (`Number` / `topo_sort`, from the star imports) and JAX.
testmine = Number(2.)
mysigmoid = 1/(1+math.e**-testmine)

mysigmoid.backprop(should_print=False)
print(topo_sort(mysigmoid))


def jaxsigmoid(x):
    """Sigmoid of the sum of `x`; summing makes the output a scalar, as value_and_grad requires."""
    x = jnp.sum(x)
    return 1 / (1 + jnp.exp(-x))


testjax = jnp.array([2.])
# argnums=0 (a plain int): `grads` has the same shape as `testjax`.
sigmoided_value, grads = jax.value_and_grad(jaxsigmoid, argnums=0)(testjax)

print("value comparison:", f"Mine {mysigmoid}", f"Jax {sigmoided_value}")
print("Grad comparison:", f"Mine {[testmine.grad]}", f"Jax {grads}")
# Fail loudly if the two implementations disagree (JAX computes in float32).
np.testing.assert_allclose(testmine.grad, np.asarray(grads)[0], rtol=1e-5)

[Number(0.8807970779778823), Number(1.1353352832366128), Number(0.1353352832366127), Number(-2.0), Number(2.0), Number(-1), Number(2.718281828459045), Number(1), Number(1)]
value comparison: Mine Number(0.8807970779778823) Jax 0.8807970285415649
Grad comparison: Mine [0.1049935854035065] Jax[0.10499357]


In [11]:
def jax_weight_matrix(shape, naive=False, low=-.2, high=.2):
    """Create an initial parameter array of the given shape.

    shape: an integer (length of a 1-D array) or a sequence of integers (dimensions).
        Dimensions should be non-zero.
    naive: if True, return a deterministic jnp array filled with 0.0, 0.1, 0.2, ...
        in row-major order. This is useful for comparing implementations on identical
        inputs. Otherwise return a NumPy array drawn uniformly from [low, high) with
        NumPy's global RNG.
    low, high: bounds of the uniform distribution (ignored when naive=True).
    """
    import numbers

    # numbers.Integral also covers NumPy integer scalars such as np.int64.
    if isinstance(shape, numbers.Integral):
        shape = [shape]
    # A tuple makes reshape work for every rank, including the empty shape ().
    shape = tuple(shape)
    count = math.prod(shape)
    if naive:
        return jnp.array([i / 10 for i in range(count)]).reshape(shape)
    # One vectorised draw consumes the global RNG stream exactly like `count`
    # consecutive scalar draws, so seeded results stay the same.
    return np.random.uniform(low=low, high=high, size=count).reshape(shape)

In [12]:
    
def jax_sigmoid(x):
    """Element-wise logistic sigmoid, 1 / (1 + exp(-x)).

    Delegates to jax.nn.sigmoid, whose gradient stays finite for large negative
    inputs. The hand-written 1/(1 + e**-x) form overflows e**-x to inf there,
    and its gradient becomes NaN.
    Non-floating inputs are promoted to JAX's default float dtype first, because
    jax.nn.sigmoid only accepts floating/complex arrays.
    """
    x = jnp.asarray(x)
    if not jnp.issubdtype(x.dtype, jnp.inexact):
        x = x.astype(float)
    return jax.nn.sigmoid(x)


def jax_relu(x, negative_slope=1e-2):
    """Leaky ReLU: `x` where x > 0, `negative_slope * x` otherwise.

    Despite the name, this is leaky by default (slope 0.01): units with negative
    pre-activations still get a gradient. Pass negative_slope=0 for a plain ReLU.
    Accepts any array-like (lists included).
    """
    x = jnp.asarray(x)
    return jnp.where(x <= 0, negative_slope * x, x)


In [None]:
# Compare sum(a @ b) and its gradients between the home-made autodiff and JAX,
# using identical deterministic ("naive") inputs.
test_shape = (3, 5)
test_jax = jax_weight_matrix(test_shape, naive=True)
test_mine = weight_matrix(test_shape, naive=True)
test_shape2 = (5, 2)
test_jax2 = jax_weight_matrix(test_shape2, naive=True)
test_mine2 = weight_matrix(test_shape2, naive=True)

my_matmul = np.sum(test_mine @ test_mine2)


def j_matmul(a, b):
    """Sum of all entries of a @ b (a scalar, as jax.value_and_grad requires)."""
    return jnp.sum(a @ b)


print(my_matmul)
j_matmuled, grads = jax.value_and_grad(j_matmul, argnums=(0, 1))(test_jax, test_jax2)
print(j_matmuled)

my_matmul.backprop()

print(grads[1].flatten())
print([thing.grad for thing in test_mine2.flat])
# The gradients agree up to float32 rounding; assert it rather than eyeballing.
np.testing.assert_allclose(np.asarray(grads[0]).ravel(), [t.grad for t in test_mine.flat], rtol=1e-5)
np.testing.assert_allclose(np.asarray(grads[1]).ravel(), [t.grad for t in test_mine2.flat], rtol=1e-5)

Number(10.65)
10.65
[1.5       1.5       1.8000001 1.8000001 2.1       2.1       2.4
 2.4       2.6999998 2.6999998]
[1.5, 1.5, 1.8, 1.8, 2.0999999999999996, 2.0999999999999996, 2.4000000000000004, 2.4000000000000004, 2.7, 2.7]


In [None]:
   
class Model():
    """Fully connected network whose parameters are plain arrays trained with JAX autodiff.

    Hidden layers apply `jax_relu`; the output layer applies `jax_sigmoid`.
    Parameters live in two tuples with one entry per layer:
    `self.layers` (weight matrices) and `self.biases` (bias vectors).
    """

    def __init__(self, input_size, output_size, hidden_layers, naive=False, seed=None):
        '''
        Build one weight matrix and one bias vector per layer.

        input_size: number of input features.
        output_size: number of output units.
        hidden_layers: widths of the hidden layers, in order.
        naive: if True, every layer (output layer included) uses the deterministic
            ramp initialisation of `jax_weight_matrix`.
        seed: if not None, seeds NumPy's global RNG before the random initialisation.
        '''
        if seed is not None:
            np.random.seed(seed)

        self.layer_sizes = hidden_layers
        layers = []
        biases = []

        prev_size = input_size
        for hidden_size in hidden_layers:
            layers.append(jax_weight_matrix([prev_size, hidden_size], naive))
            biases.append(jax_weight_matrix(hidden_size, naive))
            prev_size = hidden_size

        # Output layer. The bias is drawn before the weights; keep this order so
        # seeded random initialisations stay the same.
        biases.append(jax_weight_matrix([output_size], naive))
        layers.append(jax_weight_matrix([prev_size, output_size], naive))

        self.layers = tuple(layers)
        self.biases = tuple(biases)

    @staticmethod
    def _forward(weights, biases, x):
        """Apply each layer to `x`: affine map, then leaky ReLU (hidden) or sigmoid (last)."""
        last = len(weights) - 1
        for i, (w, b) in enumerate(zip(weights, biases)):
            x = x @ w + b
            x = jax_sigmoid(x) if i == last else jax_relu(x)
        return x

    def fd(self, x):
        '''Forward pass of input `x` through the network using the current parameters.'''
        return self._forward(self.layers, self.biases, x)

    def loss_static(self, params, x, y):
        '''
        Sum (not mean) of squared errors between the network output for `x` and `y`.

        params: (weights, biases) pair. The parameters are passed explicitly rather
            than read from self, so jax.value_and_grad can differentiate with respect to them.
        '''
        weights, biases = params
        preds = self._forward(weights, biases, x)
        y = jnp.array(y)
        # (p - y)**2 instead of the expanded p*p - 2*p*y + y*y, which loses
        # precision to cancellation.
        return jnp.sum((preds - y) ** 2)

    def train_epoch(self, x, y, lr=10**-2):
        '''
        One epoch of plain gradient descent with one update per batch.

        x: sequence of input batches; x[k] is the input matrix for batch k.
        y: matching sequence of one-hot target batches (not sparse labels).
        lr: learning rate (gradient-descent step size).

        Returns the list of per-batch losses, each computed before that batch's update.
        Prints the accuracy and loss of the last batch only.
        '''
        if len(y) == 0:
            return []
        x = np.array(x)
        # Loop-invariant: build the value-and-gradient function once per epoch.
        loss_and_grad = jax.value_and_grad(self.loss_static, argnums=0)

        losses = []
        for batch_num in range(len(y)):
            loss, (weight_grads, bias_grads) = loss_and_grad(
                (self.layers, self.biases), x[batch_num], y[batch_num])
            losses.append(loss)

            self.layers = tuple(w - lr * g for w, g in zip(self.layers, weight_grads))
            self.biases = tuple(b - lr * g for b, g in zip(self.biases, bias_grads))

        last = len(y) - 1
        preds = self.fd(x[last])
        correct = jnp.sum(jnp.argmax(preds, axis=1) == jnp.argmax(y[last], axis=1))
        acc = correct / len(y[last])
        print(f"Acc: {acc} Loss: {loss}")
        return losses

# Training the Jax model on MNIST

In [6]:
import tensorflow as tf
import keras
%load_ext autoreload
%autoreload 2

import numpy as np


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:

# Load MNIST and shuffle the training set. Seeding NumPy's global RNG makes the
# shuffle (and later np.random draws, e.g. weight initialisation) reproducible.
SHUFFLE_SEED = 0
np.random.seed(SHUFFLE_SEED)

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data(path="mnist.npz")
indices = np.arange(len(x_train))
np.random.shuffle(indices)
x_train = x_train[indices]
y_train = y_train[indices]

In [8]:
def batch(x, y, batch_size=32, verbose=True):
    """Split paired examples/targets into equal-size batches, dropping the remainder.

    x: examples, first axis = example index.
    y: targets, first axis = example index (same order as x).
    batch_size: number of examples per batch.
    verbose: if True, print the number of batches (as a float, like before).

    Returns (x_batches, y_batches): x_batches is an ndarray of shape
    (n_batches, batch_size, *x.shape[1:]); y_batches is a list of n_batches arrays.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    n_batches = len(x) // batch_size
    # Truncate both arrays to the same whole number of batches.
    usable = n_batches * batch_size
    x, y = x[:usable], y[:usable]
    if verbose:
        print(len(x) / batch_size)
    if n_batches == 0:
        # np.split refuses 0 sections; return empty batches instead of crashing.
        return x.reshape((0, batch_size) + x.shape[1:]), []
    return np.array(np.split(x, n_batches, axis=0)), np.split(y, n_batches, axis=0)

def fix_data(x, y, num_classes=10):
    """Flatten images, scale pixel values to [0, 1], and one-hot encode labels.

    x: images, first axis = example index (e.g. MNIST uint8 images of 28x28).
    y: integer class labels, one per image, each in range(num_classes).
    num_classes: width of the one-hot encoding.

    Returns (x_flat, y_onehot) with shapes (n, pixels_per_image) and (n, num_classes).
    """
    x = np.asarray(x)
    n = x.shape[0]
    # Explicit width (not -1) so an empty batch still reshapes correctly.
    pixels = int(np.prod(x.shape[1:]))
    x_flat = x.reshape(n, pixels) / 255
    y_onehot = np.zeros((n, num_classes))
    y_onehot[np.arange(n), np.asarray(y)] = 1
    return (x_flat, y_onehot)

In [9]:
# Take the first 10k (already shuffled) training images, preprocess them,
# then cut them into batches of the default size.
N_TRAIN_SAMPLES = 10000
fixed_x, fixed_y = fix_data(x_train[:N_TRAIN_SAMPLES], y_train[:N_TRAIN_SAMPLES])
b_x, b_y = batch(fixed_x, fixed_y)


312.0


In [29]:
# Architecture and training hyper-parameters for the MNIST run.
INPUT_SIZE = 28 * 28
NUM_CLASSES = 10
HIDDEN_LAYERS = [8, 16]
N_EPOCHS = 10
LEARNING_RATE = 1e-3
MODEL_SEED = 0  # fixes the weight initialisation so runs are reproducible

my_model = Model(INPUT_SIZE, NUM_CLASSES, HIDDEN_LAYERS, seed=MODEL_SEED)
datas = []  # one list of per-batch losses per epoch
for _epoch in range(N_EPOCHS):
    datas.append(my_model.train_epoch(b_x, b_y, lr=LEARNING_RATE))

Acc: 0.21875 Loss: 29.365278244018555
Acc: 0.21875 Loss: 29.098339080810547
Acc: 0.34375 Loss: 28.30455780029297
Acc: 0.40625 Loss: 24.284332275390625
Acc: 0.4375 Loss: 21.402801513671875
Acc: 0.5 Loss: 20.755481719970703


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline  

# Show the first 10 test digits with the true label and each model's prediction.
# NOTE(review): `model` (a Keras model) is never defined in this notebook. It must
# come from an earlier session or from the `from model import *` star import;
# confirm. Looking it up in globals() keeps this cell runnable on a fresh kernel.
keras_model = globals().get("model")

fig, axes = plt.subplots(2, 5, figsize=(10, 7))
for ax, img, label in zip(axes.flat, x_test[:10], y_test[:10]):
    # The Jax model was trained on flattened pixels scaled to [0, 1] (see fix_data),
    # so apply the same preprocessing here instead of feeding raw 0-255 values.
    predicted = my_model.fd(img.reshape(28 * 28) / 255)
    if keras_model is not None:
        keras_label = np.argmax(keras_model.predict(img.reshape(1, 28 * 28)))
    else:
        keras_label = "n/a"
    ax.axis('off')
    ax.set_title(f"T {label} mine {np.argmax(predicted)} keras {keras_label} ")
    ax.imshow(img)
plt.show()
# ~60% accuracy: not terrible for a network built nearly from scratch.