# Imports 

In [1]:
%load_ext autoreload
%autoreload 2
import jax.numpy as jnp
from jax import grad
import jax
import numpy as np
import math
%matplotlib inline
from grad import *
from model import *

# Test of my autodiff library 

In [None]:
# Small expression graph: y = (x1 + x2) * x5 - x3.  (x4 is created but not used in y.)
x1, x2, x3, x4, x5 = [Number(v) for v in (1, 2, 3, 4, 5)]
y = (x1 + x2) * x5 - x3

top_sorted = topo_sort(y)
print(top_sorted)

# Pass 1: reset gradients, then push gradients node by node in topological order.
y.null_gradients()
for node in top_sorted:
    node.backprop_single()
print([(node.grad, node) for node in top_sorted])

# Pass 2: reset again and let the library's all-in-one backprop do the same job,
# printing the gradients for comparison with pass 1.
y.null_gradients()
y.backprop()
print([(node.grad, node) for node in top_sorted], "key")

In [30]:
# Compare my autodiff against JAX on sigmoid(x) at x = 2.
testmine = Number(2.)
mysigmoid = 1/(1+math.e**-testmine)

mysigmoid.backprop(should_print=False)
print(topo_sort(mysigmoid))

def jaxsigmoidsum(x):
    """Sigmoid of sum(x); summing reduces x to a scalar so value_and_grad can differentiate it."""
    x = jnp.sum(x)
    return 1 / (1 + jnp.exp(-x))

testjax = jnp.array([2.])
sigmoided_value, grads = jax.value_and_grad(jaxsigmoidsum, argnums=0)(testjax)

print("value comparison:", f"Mine {mysigmoid}", f"Jax {sigmoided_value}")
print("Grad comparison:", f"Mine {[testmine.grad]}", f"Jax {grads}")

# JAX computes in float32 by default, so agreement is only expected to ~1e-7 relative;
# fail loudly instead of relying on eyeballing the printed values.
np.testing.assert_allclose(testmine.grad, float(grads[0]), rtol=1e-5)

[Number(0.8807970779778823), Number(1.1353352832366128), Number(0.1353352832366127), Number(-2.0), Number(2.0), Number(-1), Number(2.718281828459045), Number(1), Number(1)]
value comparison: Mine Number(0.8807970779778823) Jax 0.8807970285415649
Grad comparison: Mine [0.1049935854035065] Jax[0.10499357]


In [None]:
def jax_sigmoid(x):
    """Elementwise logistic sigmoid 1 / (1 + e^-x) for a scalar or an array of any shape.

    jnp.exp already operates elementwise (and broadcasts), so no jnp.vectorize
    wrapper around a scalar lambda is needed.
    """
    return 1 / (1 + jnp.exp(-jnp.asarray(x)))

def jax_weight_matrix(shape, naive=False, low=-.2, high=.2):
    """Build a weight matrix with the given dimensions.

    Args:
        shape: an int (for a 1-D result) or a sequence of ints giving the dimensions.
        naive: if True, return a deterministic JAX array filled with 0.0, 0.1, 0.2, ...
            in row-major order (useful for comparing against another implementation).
        low, high: bounds of the uniform distribution used when ``naive`` is False.

    Returns:
        A jnp array when ``naive`` is True; otherwise a NumPy array of uniform samples
        drawn from numpy's global RNG (seed it with ``np.random.seed`` for reproducibility).
    """
    if isinstance(shape, (int, np.integer)):
        shape = (shape,)
    shape = tuple(shape)
    number = math.prod(shape)
    if naive:
        # i / 10 is computed in float64 (as the original list comprehension did)
        # before being handed to JAX.
        return jnp.asarray(np.arange(number) / 10).reshape(shape)
    # One vectorized draw instead of `number` separate scalar draws.
    return np.random.uniform(low=low, high=high, size=number).reshape(shape)

In [31]:
# Same naive weights in both libraries: A is (3, 5), B is (5, 2).
test_shape = (3, 5)
test_jax = jax_weight_matrix(test_shape, naive=True)
test_mine = weight_matrix(test_shape, naive=True)
test_shape2 = (5, 2)
test_jax2 = jax_weight_matrix(test_shape2, naive=True)
test_mine2 = weight_matrix(test_shape2, naive=True)

# My library: sum(A @ B) built up from Number objects.
my_matmul = np.sum(test_mine @ test_mine2)

def j_matmul(a, b):
    """Reference sum(a @ b) in JAX; the scalar output lets value_and_grad differentiate it."""
    return jnp.sum(a @ b)

# Value and gradients w.r.t. both operands in one call (no need to evaluate j_matmul twice).
j_matmuled, grads = jax.value_and_grad(j_matmul, argnums=(0, 1))(test_jax, test_jax2)

print(my_matmul)
print(j_matmuled)

my_matmul.backprop()

jax_grad_b = grads[1].flatten()
my_grad_b = [thing.grad for thing in test_mine2.flat]
print(jax_grad_b)
print(my_grad_b)
# JAX works in float32, so expect agreement to roughly 1e-7 relative, not bit-for-bit.
np.testing.assert_allclose(my_grad_b, np.asarray(jax_grad_b), rtol=1e-5)

Number(10.65)
10.65
[1.5       1.5       1.8000001 1.8000001 2.1       2.1       2.4
 2.4       2.6999998 2.6999998]
[1.5, 1.5, 1.8, 1.8, 2.0999999999999996, 2.0999999999999996, 2.4000000000000004, 2.4000000000000004, 2.7, 2.7]


# Overfitting a small subset (1,000 images)

In [2]:
import tensorflow as tf
import keras
%load_ext autoreload
%autoreload 2

import numpy as np


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data(path="mnist.npz")

# Shuffle the training set with a fixed seed so every Restart & Run All sees the
# same sample order (and therefore the same 1k / 10k subsets below).
SHUFFLE_SEED = 0
indices = np.random.default_rng(SHUFFLE_SEED).permutation(len(x_train))
x_train = x_train[indices]
y_train = y_train[indices]

def batch(x, y, batch_size=32):
    """Split paired arrays ``x`` and ``y`` into equal-sized mini-batches along axis 0.

    Trailing samples that do not fill a complete batch are dropped, so every
    batch has exactly ``batch_size`` rows.

    Args:
        x: inputs, indexed along axis 0.
        y: targets, with the same length as ``x``.
        batch_size: rows per batch; must be positive.

    Returns:
        ``(x_batches, y_batches)``: two lists of arrays of equal length. Both are
        empty when there are fewer than ``batch_size`` samples.

    Raises:
        ValueError: if ``len(x) != len(y)`` or ``batch_size < 1``.
    """
    if len(x) != len(y):
        raise ValueError(f"x and y must have the same length, got {len(x)} and {len(y)}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    # Compute the batch count once from the original length and trim both arrays
    # with it, so x and y can never end up trimmed differently.
    n_batches = len(x) // batch_size
    if n_batches == 0:
        return [], []
    keep = n_batches * batch_size
    return (np.split(x[:keep], n_batches, axis=0),
            np.split(y[:keep], n_batches, axis=0))

def fix_data(x, y, num_classes=10):
    """Flatten images, scale pixel values into [0, 1], and one-hot encode the labels.

    Args:
        x: images indexed along axis 0 (e.g. MNIST's (n, 28, 28) array). All remaining
            axes are flattened into one feature axis. Pixels are divided by 255, which
            assumes 8-bit pixel values.
        y: integer class labels of length n, each in ``[0, num_classes)``.
        num_classes: width of each one-hot label vector (10 for MNIST digits).

    Returns:
        ``(x_flat, y_one_hot)``: float arrays of shape (n, features) and (n, num_classes).
    """
    x = np.asarray(x)
    n_samples = x.shape[0]
    # int(prod(...)) rather than reshape(n, -1) so an empty batch (n == 0) still works.
    n_features = int(np.prod(x.shape[1:]))
    x_flat = x.reshape(n_samples, n_features) / 255
    one_hot = np.zeros((n_samples, num_classes))
    one_hot[np.arange(n_samples), y] = 1
    return (x_flat, one_hot)

# Small sanity-check subset: the first 1,000 (shuffled) training images,
# preprocessed and cut into 32-row mini-batches.
first_1000 = slice(None, 1000)
fixed_x, fixed_y = fix_data(x_train[first_1000], y_train[first_1000])
b_x, b_y = batch(fixed_x, fixed_y, batch_size=32)


In [5]:

# Sanity run: train on the 1,000-image subset (b_x, b_y) for 10 epochs and keep
# whatever train_epoch returns for each epoch in `datas`.
# NOTE(review): Model args presumably mean (input size, output size, hidden sizes) — confirm in model.py.
my_model = Model(28 * 28, 10, [8, 16])
datas = []
for epoch_idx in range(10):
    print(f"starting epoch {epoch_idx}")
    epoch_result = my_model.train_epoch(
        b_x, b_y, lr=1, timer=False, batch_timer=False
    )
    datas.append(epoch_result)
# Author's note: the loss goes down over these epochs on the small subset.

starting epoch 0
[[Number(5.1010880362035633e-08) Number(9.338977984736345e-09)
  Number(2.8903566673143343e-06) Number(5.126523416196103e-08)
  Number(2.3552644732446952e-07) Number(2.027152204304169e-05)
  Number(3.384781919372198e-08) Number(6.918581099002038e-10)
  Number(5.1484685790866545e-09) Number(1.5024589695836614e-07)]
 [Number(0.0002684087058041381) Number(0.00010845026752175577)
  Number(0.0015633470709564435) Number(0.00020293521255668952)
  Number(0.00036582101973399185) Number(0.0028623524842888164)
  Number(0.0001443077506618142) Number(4.3047795436393466e-05)
  Number(0.00011214943173866973) Number(0.0006296488037372414)]
 [Number(1.918242019791202e-06) Number(4.772063647064944e-07)
  Number(4.048811685754022e-05) Number(1.7073169609073928e-06)
  Number(5.208912030645319e-06) Number(0.00015495049033421122)
  Number(1.1436706530865938e-06) Number(7.167557464409456e-08)
  Number(3.626945727450202e-07) Number(5.395131553999207e-06)]
 [Number(3.766422478893002e-05) Numbe

KeyboardInterrupt: 

In [None]:
#Code to display img

# Attempt to actually train it on multiple images

Because of some terrible architecture decisions on my part, training here is very slow. (One reason is that the topological sort is not cached, which would be fairly difficult to implement given how `Number` objects are created. I did not build this homemade autodiff library for speed, or even to truly train something; I built it to understand autodiff, which I think it has succeeded in doing, as the cells above demonstrate.)

Running any of the cells below may take anywhere from half an hour to an hour to fully train.

In [36]:
# Larger 10,000-image subset. Batch and train on *this* data — not the 1,000-image
# fixed_x / b_x subset from the sanity run above.
full_x, full_y = fix_data(x_train[:10000], y_train[:10000])
full_b_x, full_b_y = batch(full_x, full_y)

my_model = Model(28*28, 10, [8, 16])
datas = []
for _epoch in range(30):
    print(f"starting epoch {_epoch}")
    datas.append(my_model.train_epoch(full_b_x, full_b_y, lr=1e-2, timer=False, batch_timer=False))

312.0


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Show the first 10 test digits, each titled with the true label ("T") and the
# model's prediction ("mine").
n_show = 10
# The model was trained on fix_data output (flattened, pixels divided by 255), so
# preprocess the test images the same way; raw 0-255 pixels would not match training.
test_inputs, _ = fix_data(x_test[:n_show], y_test[:n_show])

fig, axes = plt.subplots(2, 5, figsize=(10, 7))
for ax, img, model_input, true_label in zip(axes.flat, x_test[:n_show], test_inputs, y_test[:n_show]):
    predicted = my_model.fd(model_input)
    ax.axis('off')
    ax.set_title(f"T {true_label} mine {np.argmax(predicted)} ")
    ax.imshow(img)
plt.show()
# Author's note: ~60% accuracy — not terrible for a from-scratch autodiff library.
# NOTE(review): that figure was observed with unscaled inputs; re-measure with the 1/255 scaling.