# MNIST-minus-minus: Train and test baselines for punky datasets

## Authors
- **David W Hogg** (NYU) (Flatiron)
- **Soledad Villar** (JHU)

## To-Do / Bugs:
- Possibly switch the MLP to a CNN?
- How to assess / check that the model is doing a good job where it appears to be?
 - For example, can we find the orientations of 6s and 9s in a pure 6-9++ sample?
- Need to apply learned transformations to the test data after learning the group elements.
- Figure out how to run on MNIST+4 labels.
 - Maybe 4 classifications on each of the 4 labels separately?
- Figure out how to run on MNIST+Inf group elements.
 - Maybe just switch to a regression with 4 outputs?
- Maybe implement a "reasoning" system that learns orientation first, applies it, and then does image contents?

## Notes
- We got some of the sample code from the (excellent) *jax* examples documentation.

In [1]:
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap, random
from jax.scipy.special import logsumexp
import tensorflow as tf
import tensorflow_datasets as tfds
import time
import os
import ssl
from urllib.request import urlopen
from shutil import copyfileobj
import gzip
import pickle

## Read in all MNIST++ datasets

In [2]:
# Swap the default HTTPS context factory for the unverified one.
# NOTE(review): this disables certificate checking for default-context HTTPS
# requests made by this process -- presumably a workaround for a certificate
# problem on the data host; confirm it is still needed.
ssl._create_default_https_context = ssl._create_unverified_context
# URL prefix that `get_and_read_pickle` prepends to each dataset filename.
baseurl = "https://cosmo.nyu.edu/hogg/research/2023/04/17/"

In [3]:
def get_and_read_pickle(filename, clobber=False):
    """Download `filename` from `baseurl` if needed, then unpickle it.

    Args:
        filename: name of a gzip-compressed pickle. It is appended to the
            module-level `baseurl` for the download and is also the local
            path under which the file is cached.
        clobber: if True, download again even when a local copy exists.

    Returns:
        Whatever `pickle.load` yields for the decompressed file.
    """
    if clobber or not os.path.isfile(filename):
        # Download under a temporary name and rename only on success, so an
        # interrupted transfer never leaves a truncated file that later calls
        # would mistake for a valid cached copy.
        partial = filename + ".part"
        try:
            with urlopen(baseurl + filename) as in_stream, open(partial, 'wb') as out_file:
                copyfileobj(in_stream, out_file)
            os.replace(partial, filename)
        finally:
            if os.path.exists(partial):
                os.remove(partial)
    # SECURITY: unpickling executes arbitrary code embedded in the file; only
    # use this with data sources you trust.
    with gzip.open(filename, 'rb') as file:
        return pickle.load(file)

In [4]:
# Read SixtyNine++ (downloaded on first use, then read from the local copy).
# The pickle unpacks into a train triple and a test triple, each (X, M, y).
(X_train69, M_train69, y_train69), (X_test69, M_test69, y_test69) = get_and_read_pickle("SixtyNine++.pkl.gz")

In [5]:
# Read LowRes++ (downloaded on first use, then read from the local copy).
# The pickle unpacks into a train triple and a test triple, each (X, M, y).
(X_trainLow, M_trainLow, y_trainLow), (X_testLow, M_testLow, y_testLow) = get_and_read_pickle("LowRes++.pkl.gz")

In [14]:
# Read CutOut++ (downloaded on first use, then read from the local copy).
# The pickle unpacks into a train triple and a test triple, each (X, M, y).
(X_trainCut, M_trainCut, y_trainCut), (X_testCut, M_testCut, y_testCut) = get_and_read_pickle("CutOut++.pkl.gz")

## Set up MLP model
*Note:* Most of this code is copied from the *jax* documentation.

In [6]:
def relu(x):
    """Rectified linear unit: the elementwise maximum of zero and `x`."""
    floor = 0
    return jnp.maximum(floor, x)

def predict(params, image):
    """Per-example forward pass returning normalized log-scores.

    `params` is a sequence of (weights, biases) pairs. Every pair except the
    last is applied as an affine map followed by `relu`; the last pair is
    applied as a plain affine map. The resulting scores are shifted by their
    log-sum-exp.
    """
    hidden_layers, output_layer = params[:-1], params[-1]

    hidden = image
    for weight, bias in hidden_layers:
        hidden = relu(jnp.dot(weight, hidden) + bias)

    out_weight, out_bias = output_layer
    scores = jnp.dot(out_weight, hidden) + out_bias
    # Subtracting logsumexp makes exp(result) sum to one.
    return scores - logsumexp(scores)

In [7]:
# A helper function to randomly initialize weights and biases
# for a dense neural network layer
def random_layer_params(m, n, key, scale=1e-2):
    """Draw random weights and biases for one dense layer.

    `key` is split once into two subkeys. Returns a (weights, biases) pair
    of shapes (n, m) and (n,), each a `random.normal` draw multiplied by
    `scale`.
    """
    weight_key, bias_key = random.split(key)
    weights = scale * random.normal(weight_key, (n, m))
    biases = scale * random.normal(bias_key, (n,))
    return weights, biases

# Initialize all layers for a fully-connected neural network with sizes "sizes"
def init_network_params(sizes, key):
    """Build one (weights, biases) pair per consecutive pair in `sizes`.

    Returns a list with len(sizes) - 1 entries, each produced by
    `random_layer_params` with its own subkey.
    """
    # One subkey is drawn per entry of `sizes`; zip stops at the shortest
    # input, so the final subkey is never consumed.
    subkeys = random.split(key, len(sizes))
    layers = []
    for fan_in, fan_out, subkey in zip(sizes[:-1], sizes[1:], subkeys):
        layers.append(random_layer_params(fan_in, fan_out, subkey))
    return layers

In [8]:
def one_hot(X, label_list, dtype=jnp.float32):
    """Create a one-hot encoding of `X` against `label_list`.

    Args:
        X: sequence of labels, one per sample; each label may be a scalar
            or an array.
        label_list: the distinct labels. Column j of the result marks the
            samples equal to label_list[j].
        dtype: dtype of the returned array.

    Returns:
        Array of shape (len(X), len(label_list)) with entries 1 where the
        sample matches the label and 0 elsewhere. For array-valued labels a
        match requires every component to be equal.
    """
    samples = np.asarray(X)
    labels = np.asarray(label_list)
    # Insert singleton axes right after the sample axis so that every sample
    # broadcasts against the whole label list, exactly as `x == label_list`
    # would for one sample at a time -- but in a single vectorized comparison.
    pad = max(labels.ndim - (samples.ndim - 1), 0)
    samples = samples.reshape(samples.shape[:1] + (1,) * pad + samples.shape[1:])
    matches = samples == labels
    # Collapse component-wise comparisons while the data is still boolean;
    # casting only at the end keeps the requested dtype in every case.
    while matches.ndim > 2:
        matches = matches.all(axis=-1)
    return jnp.asarray(matches, dtype=dtype)

# Make a batched version of the `predict` function: `params` is shared by
# every example (in_axes None) and the images are mapped over their leading
# axis (in_axes 0).
batched_predict = vmap(predict, in_axes=(None, 0))

def accuracy(params, images, targets):
    """Fraction of examples whose arg-max prediction equals the arg-max target.

    `targets` is compared through its arg-max along axis 1, as are the
    outputs of `batched_predict`.
    """
    scores = batched_predict(params, images)
    guessed = jnp.argmax(scores, axis=1)
    expected = jnp.argmax(targets, axis=1)
    hits = guessed == expected
    return jnp.mean(hits)

def loss(params, images, targets):
    """Negative mean of the batched predictions weighted by `targets`.

    The mean runs over every entry of the elementwise product, i.e. over
    both the example axis and the class axis.
    """
    log_scores = batched_predict(params, images)
    weighted = log_scores * targets
    return -jnp.mean(weighted)

@jit
def update(params, x, y, step_size):
    """Take one gradient-descent step on `loss` and return the new parameters.

    The result is a list of (weights, biases) tuples parallel to `params`,
    each moved against its gradient by `step_size`.
    """
    gradients = grad(loss)(params, x, y)
    stepped = []
    for (weights, biases), (d_weights, d_biases) in zip(params, gradients):
        new_weights = weights - step_size * d_weights
        new_biases = biases - step_size * d_biases
        stepped.append((new_weights, new_biases))
    return stepped

In [9]:
# Hide all GPUs from TensorFlow so that it does not grab all of the GPU memory.
# TF is only used here for the input pipeline (tf.data).
tf.config.set_visible_devices([], device_type='GPU')

In [10]:
# Train the MLP on one dataset and report per-epoch accuracy (hyperparameters are set inside)
def train_and_test(X_train, y_train, X_test, y_test, epochs=8):
    """Train the MLP on one dataset and print train/test accuracy per epoch.

    Args:
        X_train, X_test: image stacks indexed by sample along the first axis.
        y_train, y_test: labels, one per image.
        epochs: number of passes over the training set.

    The set of classes is taken from the distinct values of `y_train`.
    Nothing is returned; all results are reported with `print`.
    """
    assert len(X_train) == len(y_train), "X_train and y_train differ in length"
    assert len(X_test) == len(y_test), "X_test and y_test differ in length"

    # Number of values per image once flattened (any per-image shape works).
    num_pixels = int(np.prod(np.shape(X_train)[1:]))
    label_list = np.unique(y_train, axis=0)
    num_labels = len(label_list)
    layer_sizes = [num_pixels, 512, 512, num_labels] # MAGIC
    step_size = 0.01 # MAGIC
    batch_size = 128 # MAGIC
    print("Found {} distinct labels, and input images with {} pixels".format(num_labels, num_pixels))
    print("The labels are {}".format(label_list))

    # Flatten the images and encode the labels once for the whole dataset.
    train_images = jnp.reshape(X_train, (len(X_train), num_pixels))
    train_labels = one_hot(y_train, label_list)
    test_images = jnp.reshape(X_test, (len(X_test), num_pixels))
    test_labels = one_hot(y_test, label_list)
    print("Now train_labels is {}".format(train_labels.shape))
    print("Their sums are {}".format(np.sum(train_labels, axis=0)))

    # Batch the already-flattened images and already-encoded labels so that
    # no reshaping or one-hot encoding is repeated per batch in every epoch.
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (np.asarray(train_images), np.asarray(train_labels)))
    batches = train_dataset.batch(batch_size).prefetch(1)

    params = init_network_params(layer_sizes, random.PRNGKey(0))
    for epoch in range(epochs):
        start_time = time.time()
        # NOTE(review): batches are visited in the same order every epoch;
        # there is no shuffling step in this pipeline.
        for x, y in tfds.as_numpy(batches):
            params = update(params, x, y, step_size)
        epoch_time = time.time() - start_time

        train_acc = accuracy(params, train_images, train_labels)
        test_acc = accuracy(params, test_images, test_labels)
        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set accuracy {}".format(train_acc))
        print("Test set accuracy {}".format(test_acc))

## Train and test MLP model on the 5 easy cases

In [None]:
# The Challenge nobody asked for
# Train and test on the CutOut++ arrays for 32 epochs (the default is 8).
train_and_test(X_trainCut, y_trainCut, X_testCut,  y_testCut, epochs=32)

Found 10 distinct labels, and input images with 784 pixels
The labels are [0 1 2 3 4 5 6 7 8 9]
Now train_labels is (60000, 10)
Their sums are [5923. 6742. 5958. 6131. 5842. 5421. 5918. 6265. 5851. 5949.]
Epoch 0 in 3.51 sec
Training set accuracy 0.5878999829292297
Test set accuracy 0.5920999646186829
Epoch 1 in 3.81 sec
Training set accuracy 0.6246333122253418
Test set accuracy 0.6299999952316284
Epoch 2 in 4.14 sec
Training set accuracy 0.6437000036239624
Test set accuracy 0.6477999687194824
Epoch 3 in 4.04 sec
Training set accuracy 0.6564000248908997
Test set accuracy 0.6554999947547913
Epoch 4 in 4.51 sec
Training set accuracy 0.6644333600997925
Test set accuracy 0.6626999974250793
Epoch 5 in 3.80 sec
Training set accuracy 0.6702166795730591
Test set accuracy 0.6687999963760376
Epoch 6 in 5.85 sec
Training set accuracy 0.67535001039505
Test set accuracy 0.6729999780654907
Epoch 7 in 5.18 sec
Training set accuracy 0.6796833276748657
Test set accuracy 0.6773999929428101
Epoch 8 in 5.

In [None]:
# The Challenge nobody asked for
# Train and test on the LowRes++ arrays for 32 epochs (the default is 8).
train_and_test(X_trainLow, y_trainLow, X_testLow,  y_testLow, epochs=32)