<a href="https://colab.research.google.com/github/profteachkids/STEMUnleashed2023/blob/main/MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install datasets
!pip install einops

Collecting datasets
  Downloading datasets-2.13.1-py3-none-any.whl (486 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m486.2/486.2 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.7,>=0.3.0 (from datasets)
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/212.5 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0.0,>=0.11.0 (from datasets)
  Downloading huggingface_hub-0.16.4-py3-none-a

In [2]:
import jax.numpy as jnp
import jax
# jax.config.update("jax_enable_x64", True)
import numpy as np
import pandas as pd
from plotly.subplots import make_subplots
import plotly.express as px
from einops import rearrange
from scipy.optimize import minimize

In [3]:
from datasets import load_dataset

dataset = load_dataset("mnist", split="train")
test_dataset = load_dataset("mnist", split="test")

Downloading builder script:   0%|          | 0.00/3.98k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.21k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/6.83k [00:00<?, ?B/s]

Downloading and preparing dataset mnist/mnist to /root/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332...


Downloading data files:   0%|          | 0/4 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/9.91M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/28.9k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.65M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.54k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/4 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset mnist downloaded and prepared to /root/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332. Subsequent calls will reuse this data.




In [4]:
N=5000
train_imgs=np.stack(dataset[:N]['image'])
train_labels=np.zeros((N,10))
train_labels[np.arange(N),dataset[:N]['label']]=1

In [5]:
plot_img=rearrange(train_imgs[:9], '(b1 b2) h w -> (b1 h) (b2 w)', b1=3)

In [6]:
fig=px.imshow(plot_img, color_continuous_scale='gray')
fig.update_layout(width=300,height=300,template='plotly_dark',coloraxis_showscale=False)
fig.show()

In [7]:
train=rearrange(train_imgs, 'b h w -> b (h w)').astype(np.float64)

In [8]:
rng=np.random.RandomState(123)
n0=784
n1=100
n2=10
sqrt6=np.sqrt(6)

w=dict(w1 = jnp.asarray(rng.uniform(-sqrt6/np.sqrt(n0+n1), sqrt6/np.sqrt(n0+n1),size=(n1,n0))), b1=jnp.zeros(n1),
       w2 = jnp.asarray(rng.uniform(-sqrt6/np.sqrt(n1+n2), sqrt6/np.sqrt(n1+n2), size=(n2,n1))), b2=jnp.zeros(n2))




In [10]:
batch_matmul = jax.vmap(jnp.matmul, (None, 0))

In [11]:
def nn(w, data):
    a1=batch_matmul(w['w1'], data) + w['b1']
    o1 = jnp.tanh(a1)
    a2=batch_matmul(w['w2'], o1) + w['b2']
    o2 = jax.nn.softmax(a2)
    return o2

In [12]:
def cross_entropy(w, data, labels):
    return -jnp.sum(labels*jnp.log(nn(w, data)))

In [13]:
cross_entropy_grad=jax.jit(jax.grad(cross_entropy))
jax_cross_entropy=jax.jit(cross_entropy)

In [14]:
cross_entropy_grad(w,train, train_labels)

{'b1': DeviceArray([-0.47642651,  1.39755987,  0.48862499, -0.429673  ,
              -0.83526989,  0.92325612,  1.05935943, -0.22107826,
              -0.79240337, -0.70355668, -0.26882259,  0.59540128,
               0.4767081 , -0.20931971, -1.79844689,  0.1378823 ,
              -2.17180496, -0.95457265, -1.80268246, -0.34693464,
              -2.04853348, -0.01467416,  0.10666417, -0.87681838,
              -0.91743557,  2.38020519, -2.06771059,  1.41664657,
               0.86323236,  1.41703022,  2.62812523,  0.06843469,
              -3.03321682, -0.88455897, -1.67630929, -0.54534676,
               1.01746623,  0.15519708,  0.70397994, -1.66232168,
              -1.02872764, -1.55483326, -1.0769564 , -1.15905825,
               2.26706143, -0.16919138, -1.21595092, -0.66732565,
               0.0055153 , -1.23378156, -0.21055214, -0.92483293,
               2.27655167, -0.25308397,  2.04351365, -1.29248361,
              -1.56817634,  0.85016951,  1.8428055 ,  1.16332651,
    

In [15]:
gsum=jax.tree_util.tree_map(lambda x: jnp.array(0), w)
ydefmax=jax.tree_util.tree_map(lambda x: 1e-3, w)
y0=w
yorig=y0

In [16]:
#Distance over Gradient - Stochastic Gradient Descent
#https://arxiv.org/pdf/2302.12022.pdf


for i in range(int(20e3)):
    g=cross_entropy_grad(y0, train, train_labels)
    gsum=jax.tree_map(lambda x,y: jnp.sum(x**2)+y, g, gsum)
    eta= jax.tree_map(lambda x,y: x/jnp.sqrt(y), ydefmax, gsum)
    y1 = jax.tree_map(lambda x,y,z: x-y*z, y0,eta,g)
    ydelta=jax.tree_map(lambda x,y: jnp.linalg.norm(x-y), y1, yorig)
    ydefmax = jax.tree_map(lambda x,y: x if x > y else y, ydelta, ydefmax)
    y0=y1
    if i % 1000 == 0:
        print(i, jax_cross_entropy(y0, train, train_labels), eta)

0 14248.178581544475 {'b1': DeviceArray(8.09016389e-05, dtype=float64), 'b2': DeviceArray(1.32799309e-06, dtype=float64), 'w1': DeviceArray(3.88183805e-08, dtype=float64), 'w2': DeviceArray(1.57289782e-07, dtype=float64)}
1000 124.97084300310371 {'b1': DeviceArray(0.00342963, dtype=float64), 'b2': DeviceArray(0.00024799, dtype=float64), 'w1': DeviceArray(7.77423028e-06, dtype=float64), 'w2': DeviceArray(0.00042031, dtype=float64)}
2000 50.43947780400822 {'b1': DeviceArray(0.00348423, dtype=float64), 'b2': DeviceArray(0.00031757, dtype=float64), 'w1': DeviceArray(7.86082416e-06, dtype=float64), 'w2': DeviceArray(0.00058987, dtype=float64)}
3000 28.244543917923274 {'b1': DeviceArray(0.00349416, dtype=float64), 'b2': DeviceArray(0.00036546, dtype=float64), 'w1': DeviceArray(7.89358871e-06, dtype=float64), 'w2': DeviceArray(0.0006986, dtype=float64)}
4000 18.483351195607575 {'b1': DeviceArray(0.00350751, dtype=float64), 'b2': DeviceArray(0.0003993, dtype=float64), 'w1': DeviceArray(7.90950

KeyboardInterrupt: ignored

In [17]:
test_N=30
test_imgs=np.stack(test_dataset[:test_N]['image'])
test_data=rearrange(test_imgs, 'b h w -> b (h w)').astype(np.float64)

In [18]:
nn_label=nn(y0,test_data)

In [21]:
list(zip(np.argmax(nn_label,axis=1),np.array(test_dataset[:test_N]['label'])))

[(DeviceArray(7, dtype=int64), 7),
 (DeviceArray(2, dtype=int64), 2),
 (DeviceArray(1, dtype=int64), 1),
 (DeviceArray(0, dtype=int64), 0),
 (DeviceArray(9, dtype=int64), 4),
 (DeviceArray(1, dtype=int64), 1),
 (DeviceArray(4, dtype=int64), 4),
 (DeviceArray(9, dtype=int64), 9),
 (DeviceArray(6, dtype=int64), 5),
 (DeviceArray(9, dtype=int64), 9),
 (DeviceArray(0, dtype=int64), 0),
 (DeviceArray(6, dtype=int64), 6),
 (DeviceArray(9, dtype=int64), 9),
 (DeviceArray(0, dtype=int64), 0),
 (DeviceArray(1, dtype=int64), 1),
 (DeviceArray(5, dtype=int64), 5),
 (DeviceArray(9, dtype=int64), 9),
 (DeviceArray(7, dtype=int64), 7),
 (DeviceArray(1, dtype=int64), 3),
 (DeviceArray(4, dtype=int64), 4),
 (DeviceArray(7, dtype=int64), 9),
 (DeviceArray(6, dtype=int64), 6),
 (DeviceArray(6, dtype=int64), 6),
 (DeviceArray(5, dtype=int64), 5),
 (DeviceArray(4, dtype=int64), 4),
 (DeviceArray(0, dtype=int64), 0),
 (DeviceArray(7, dtype=int64), 7),
 (DeviceArray(4, dtype=int64), 4),
 (DeviceArray(0, dty