# MNIST-minus-minus: Train and test baselines

A handwritten-digit reading task, now with more chaos!

## Authors
- **David W Hogg** (NYU) (Flatiron)
- **Soledad Villar** (JHU)

## To-Do / Bugs:
- Run on all learning challenges.
- Figure out how to run on MNIST+4 labels.
- Figure out how to run on MNIST+Inf group elements.

## Notes
- null

In [None]:
import numpy as np
import gzip
import pickle
import os
import jax.numpy as jnp
from jax import grad, jit, vmap, random

In [None]:
# Base URL of the data files; get_and_read_pickle() appends a filename to it
# to build the download address.
baseurl = "https://cosmo.nyu.edu/hogg/research/2023/04/17/"

In [None]:
def get_and_read_pickle(filename, clobber=False):
    """Return the object stored in a gzipped pickle, downloading it if needed.

    The file is fetched from ``baseurl + filename`` with ``wget`` when no
    local copy exists or when ``clobber`` is true.

    Args:
        filename: Name appended to ``baseurl`` for the download and used as
            the local path of the file.
        clobber: If true, download again and replace any existing local copy.

    Returns:
        The unpickled contents of the file.

    Raises:
        subprocess.CalledProcessError: If ``wget`` exits with a non-zero
            status.
        OSError: If ``wget`` cannot be started or the file cannot be read.
    """
    # Function-scope import so the block does not depend on the import cell.
    import subprocess

    if clobber or not os.path.isfile(filename):
        # Download to a side file first: a failed or interrupted transfer
        # must never leave a truncated `filename` behind, because it would
        # pass the isfile() check on the next call.
        partial = filename + ".part"
        try:
            # Explicit -O is required: without it wget refuses to overwrite
            # and saves to "<filename>.1", so `clobber` had no effect.
            # An argument list (no shell) keeps the URL from being
            # interpreted by a shell.
            subprocess.run(
                ["wget", "--no-check-certificate", "-O", partial,
                 baseurl + filename],
                check=True,
            )
        except (subprocess.CalledProcessError, OSError):
            if os.path.exists(partial):
                os.remove(partial)
            raise
        os.replace(partial, filename)

    # NOTE(security): unpickling can execute arbitrary code, and the download
    # above skips certificate verification; only use with a trusted source.
    with gzip.open(filename, "rb") as file:
        return pickle.load(file)

In [None]:
# Fetch (if not already on disk) and unpack the Fashion++ train/test splits.
fashion_splits = get_and_read_pickle("Fashion++.pkl.gz")
(X_trainf, M_trainf, y_trainf), (X_testf, M_testf, y_testf) = fashion_splits
# Print the six shapes on one line: train X, M, y followed by test X, M, y.
fashion_parts = (X_trainf, M_trainf, y_trainf, X_testf, M_testf, y_testf)
print(*(part.shape for part in fashion_parts))

In [None]:
# Fetch (if not already on disk) and unpack the MNIST+4 train/test splits.
mnist4_splits = get_and_read_pickle("MNIST+4.pkl.gz")
(X_train4, M_train4, y_train4), (X_test4, M_test4, y_test4) = mnist4_splits
# Print the six shapes on one line: train X, M, y followed by test X, M, y.
mnist4_parts = (X_train4, M_train4, y_train4, X_test4, M_test4, y_test4)
print(*(part.shape for part in mnist4_parts))

In [None]:
# Label coverage of MNIST+4: how many labels never occur in each split, and
# which labels occur in the test split only.
# NOTE(review): 10000 is presumably the number of possible labels -- confirm.
train_labels_seen = set(y_train4)
test_labels_seen = set(y_test4)
print("total number of labels missing from the training set:",
      10000 - len(train_labels_seen))
print("total number of labels missing from the test set:",
      10000 - len(test_labels_seen))

# Walk the test labels in the set's own iteration order, numbering from 1
# every label that the training set never contains.
test_only_labels = (
    label for label in test_labels_seen if label not in train_labels_seen
)
for count, label in enumerate(test_only_labels, start=1):
    print(count, "label", label, "is in the test set but not in the training set")

In [None]:
# Fetch (if not already on disk) and unpack the MNIST+9 train/test splits.
mnist9_splits = get_and_read_pickle("MNIST+9.pkl.gz")
(X_train9, M_train9, y_train9), (X_test9, M_test9, y_test9) = mnist9_splits
# Print the six shapes on one line: train X, M, y followed by test X, M, y.
mnist9_parts = (X_train9, M_train9, y_train9, X_test9, M_test9, y_test9)
print(*(part.shape for part in mnist9_parts))

In [None]:
# Fetch (if not already on disk) and unpack the MNIST+Inf train/test splits.
mnist_inf_splits = get_and_read_pickle("MNIST+Inf.pkl.gz")
(X_trainInf, M_trainInf, y_trainInf), (X_testInf, M_testInf, y_testInf) = mnist_inf_splits
# Print the six shapes on one line: train X, M, y followed by test X, M, y.
mnist_inf_parts = (X_trainInf, M_trainInf, y_trainInf,
                   X_testInf, M_testInf, y_testInf)
print(*(part.shape for part in mnist_inf_parts))

In [None]:
# TODO: pack the datasets loaded above into TensorFlow dataset objects?

In [None]:
from jax.scipy.special import logsumexp

def relu(x):
    """Rectified linear unit: the element-wise maximum of ``x`` and zero."""
    return jnp.maximum(x, 0)

def predict(params, image):
    """Run the dense network on a single example.

    Args:
        params: Sequence of ``(weights, biases)`` pairs, one per layer. Every
            layer except the last is followed by a ReLU.
        image: Input vector for one example (not a batch).

    Returns:
        The output logits minus their log-sum-exp, i.e. log-probabilities
        that exponentiate to a distribution summing to one.
    """
    hidden_layers, output_layer = params[:-1], params[-1]

    # Feed the input forward through every hidden layer.
    hidden = image
    for weights, biases in hidden_layers:
        hidden = relu(jnp.dot(weights, hidden) + biases)

    # The output layer stays linear; normalise its logits in log space.
    out_weights, out_biases = output_layer
    logits = jnp.dot(out_weights, hidden) + out_biases
    log_normaliser = logsumexp(logits)
    return logits - log_normaliser

In [None]:
def random_layer_params(m, n, key, scale=1e-2):
    """Draw random initial weights and biases for one dense layer.

    Args:
        m: Number of inputs to the layer.
        n: Number of outputs of the layer.
        key: JAX PRNG key; it is split so that weights and biases use
            different sub-keys.
        scale: Factor applied to the standard-normal draws.

    Returns:
        A ``(weights, biases)`` pair with shapes ``(n, m)`` and ``(n,)``.
    """
    weight_key, bias_key = random.split(key)
    weights = scale * random.normal(weight_key, (n, m))
    biases = scale * random.normal(bias_key, (n,))
    return weights, biases

def init_network_params(sizes, key):
    """Initialise every layer of a fully-connected network.

    Args:
        sizes: Layer widths from input to output; each consecutive pair
            defines one dense layer.
        key: JAX PRNG key, split into ``len(sizes)`` sub-keys of which one is
            consumed per layer.

    Returns:
        A list with one ``random_layer_params`` result per layer.
    """
    layer_keys = random.split(key, len(sizes))
    fan_ins, fan_outs = sizes[:-1], sizes[1:]
    layers = []
    for fan_in, fan_out, layer_key in zip(fan_ins, fan_outs, layer_keys):
        layers.append(random_layer_params(fan_in, fan_out, layer_key))
    return layers

# Network architecture and training hyperparameters.
# NOTE(review): the input width (784) and output width (10) are hard-coded
# here, while later cells derive num_pixels and num_labels from the data --
# confirm that they agree for the dataset being trained on.
layer_sizes = [784, 512, 512, 10]  # layer widths, input through output
step_size = 0.01  # multiplies the gradient in update()
num_epochs = 50  # number of passes made by the training loop
batch_size = 128  # batch size applied to train_dataset in the training loop
n_targets = 10  # NOTE(review): not referenced anywhere in the visible code
# Initial weights and biases, drawn from a fixed PRNG seed (0).
params = init_network_params(layer_sizes, random.PRNGKey(0))

In [None]:
def one_hot(x, k, dtype=jnp.float32):
    """Encode the labels in ``x`` as one-hot rows of width ``k``.

    Entry ``[i, j]`` of the result is 1 when ``x[i] == j`` and 0 otherwise,
    expressed in ``dtype``.
    """
    class_ids = jnp.arange(k)
    # Compare each label (as a column) against every class id.
    matches = x[:, None] == class_ids
    return jnp.array(matches, dtype=dtype)
  
def accuracy(params, images, targets):
    """Return the fraction of examples whose predicted class is correct.

    The predicted class is the arg-max of ``batched_predict`` along axis 1;
    the true class is the arg-max of ``targets`` along axis 1.
    """
    scores = batched_predict(params, images)
    is_correct = jnp.argmax(scores, axis=1) == jnp.argmax(targets, axis=1)
    return jnp.mean(is_correct)

def loss(params, images, targets):
    """Return the negated mean of predictions times targets.

    The mean runs over every entry of the element-wise product of the
    ``batched_predict`` output and ``targets``, not only over examples.
    """
    log_probs = batched_predict(params, images)
    weighted = jnp.multiply(log_probs, targets)
    return jnp.negative(jnp.mean(weighted))

@jit
def update(params, x, y):
    """Take one gradient-descent step on ``loss`` and return the new params.

    Args:
        params: List of ``(weights, biases)`` pairs.
        x: Batch of inputs passed to ``loss``.
        y: Batch of targets passed to ``loss``.

    Returns:
        A list of ``(weights, biases)`` pairs, each moved against its
        gradient by the module-level ``step_size``.
    """
    gradients = grad(loss)(params, x, y)
    stepped = []
    for (w, b), (dw, db) in zip(params, gradients):
        stepped.append((w - step_size * dw, b - step_size * db))
    return stepped

In [None]:
import tensorflow as tf
# Hide every GPU from TensorFlow so that it cannot grab all of the GPU memory.
tf.config.set_visible_devices([], device_type='GPU')

import tensorflow_datasets as tfds

# Pixels per image: product of the first two dimensions of the first
# training image.
first_image_shape = X_trainf[0].shape
num_pixels = first_image_shape[0] * first_image_shape[1]
# Number of classes: how many distinct label values the training set holds.
distinct_train_labels = set(y_trainf)
num_labels = len(distinct_train_labels)

In [None]:
# `predict` handles a single example; use `vmap` to make it handle batches.

# Batched version of the `predict` function: `in_axes=(None, 0)` passes
# `params` unchanged to every call and maps over axis 0 of the second
# argument (the images).
batched_predict = vmap(predict, in_axes=(None, 0))

In [None]:
import time

In [None]:
# Whole-split arrays used for the per-epoch accuracy reports: one flattened
# row of num_pixels values per image, and one-hot labels of width num_labels.
n_train, n_test = len(X_trainf), len(X_testf)
train_images = jnp.reshape(X_trainf, (n_train, num_pixels))
test_images = jnp.reshape(X_testf, (n_test, num_pixels))
train_labels, test_labels = (
    one_hot(labels, num_labels) for labels in (y_trainf, y_testf)
)

# Dataset of (image, label) pairs; it is batched by the training loop.
train_dataset = tf.data.Dataset.from_tensor_slices((X_trainf, y_trainf))

In [None]:
# The batching pipeline does not change between epochs, so describe it once;
# every epoch iterates over it afresh.
batched_train = train_dataset.batch(batch_size).prefetch(1)

for epoch in range(num_epochs):
    # Time only the parameter updates, not the accuracy evaluation below.
    epoch_start = time.time()
    for batch_images, batch_labels in tfds.as_numpy(batched_train):
        flat_images = jnp.reshape(batch_images, (len(batch_images), num_pixels))
        batch_targets = one_hot(batch_labels, num_labels)
        params = update(params, flat_images, batch_targets)
    epoch_time = time.time() - epoch_start

    # Evaluate on the complete training and test sets after each epoch.
    train_acc = accuracy(params, train_images, train_labels)
    test_acc = accuracy(params, test_images, test_labels)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))