# MNIST-minus-minus: Make figures

A handwritten-digit reading task, now with more chaos!

## Authors
- **David W Hogg** (NYU) (Flatiron)
- **Soledad Villar** (JHU)

## To-Do / Bugs:
- What?

## Notes
- null

In [None]:
import gzip
import os
import pickle
import subprocess

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# File extension appended to every saved figure name.
suffix = ".png"
# Keyword arguments handed to plt.figure for every grid of panels.
figkwargs = dict(figsize=(12, 12.8), layout="tight")
# Render text, tick marks and axis spines in red.
mpl.rcParams.update({
    'text.color': 'r',
    'xtick.color': 'r',
    'ytick.color': 'r',
    'axes.edgecolor': 'r',
})

In [None]:
# Remote directory holding the gzipped pickles. get_and_read_pickle appends
# the requested file name to this URL when the file is not present locally.
baseurl = "https://cosmo.nyu.edu/hogg/research/2023/04/17/"

In [None]:
def get_and_read_pickle(filename):
    """Load a gzip-compressed pickle, downloading it first if it is absent.

    If ``filename`` is not an existing file, it is fetched with ``wget``
    from ``baseurl + filename``. The download is written to
    ``filename + ".part"`` and only renamed to ``filename`` after ``wget``
    exits successfully. A failed or interrupted download therefore never
    leaves a truncated file that a later call would mistake for a cached
    copy.

    Parameters
    ----------
    filename : str
        Local file name of the ``.pkl.gz`` file; also appended to
        ``baseurl`` to form the download URL.

    Returns
    -------
    object
        Whatever object was pickled in the file.

    Raises
    ------
    subprocess.CalledProcessError
        If ``wget`` exits with a non-zero status.
    FileNotFoundError
        If ``wget`` is not installed.
    """
    if not os.path.isfile(filename):
        partial = filename + ".part"
        # NOTE(review): --no-check-certificate disables TLS verification, and
        # the downloaded bytes are then fed to pickle.load, which can execute
        # arbitrary code. Only acceptable against a trusted host.
        cmd = ["wget", "--no-check-certificate", "-O", partial,
               baseurl + filename]
        try:
            # List form: no shell is involved, so the name cannot inject
            # commands; check=True turns a failed download into an error.
            subprocess.run(cmd, check=True)
            os.replace(partial, filename)  # atomic publish of a complete file
        finally:
            # After a successful replace this is a no-op; on any failure it
            # removes the incomplete download.
            if os.path.exists(partial):
                os.remove(partial)
    with gzip.open(filename, "rb") as file:
        return pickle.load(file)

In [None]:
def plot36(Xs, ys, name, Ms=None, fp=None):
    """Draw up to 36 images from ``Xs`` on a 6x6 grid of panels.

    Each panel shows ``Xs[i]`` with the reversed grey colormap, titled
    ``"<name> class <ys[i]>"``, with tick marks removed. When fewer than
    36 images are supplied, only that many panels are drawn.

    Parameters
    ----------
    Xs : numpy.ndarray
        Image stack indexed as ``Xs[i]``. Axis 1 is image rows (drawn
        vertically) and axis 2 is image columns (drawn horizontally).
    ys : sequence
        Label for each image, shown in the panel title.
    name : str
        Dataset name used as the title prefix.
    Ms : numpy.ndarray, optional
        If given, ``Ms[i, r, c]`` must be defined for r, c in {0, 1}. Each
        row ``r`` of ``Ms[i]`` is drawn as a red segment starting at the
        image centre, with ``Ms[i, r, 0]`` as the vertical (row) offset
        and ``Ms[i, r, 1]`` as the horizontal (column) offset. Row 1 is
        solid and row 0 is dotted.
    fp : str, optional
        If given, the figure is saved to ``fp + suffix``.

    Notes
    -----
    Reads the module-level ``figkwargs`` and ``suffix``.
    """
    nrows = ncols = 6
    dd = 20.0  # MAGIC: segment length in pixels per unit entry of Ms
    # Image centre in imshow pixel coordinates (pixel k is centred on k).
    # x follows columns (axis 2) and y follows rows (axis 1); separate
    # centres keep the overlay right for non-square images.
    xc = Xs.shape[2] / 2.0 - 0.5
    yc = Xs.shape[1] / 2.0 - 0.5
    nshow = min(nrows * ncols, len(Xs))

    plt.figure(**figkwargs)
    for i in range(nshow):
        ax = plt.subplot(nrows, ncols, i + 1)
        ax.imshow(Xs[i], cmap='gray_r', interpolation='none')
        if Ms is not None:
            # Remember the image extent so the overlay cannot widen the axes.
            xlim, ylim = ax.get_xlim(), ax.get_ylim()
            ax.plot([xc, xc + dd * Ms[i, 1, 1]],
                    [yc, yc + dd * Ms[i, 1, 0]], "r-", lw=1)
            ax.plot([xc, xc + dd * Ms[i, 0, 1]],
                    [yc, yc + dd * Ms[i, 0, 0]], "r:", lw=2)
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)
        ax.set_title(f"{name} class {ys[i]}")
        # Pure image panels: no ticks (removing ticks also removes labels).
        ax.set_xticks([])
        ax.set_yticks([])

    if fp is not None:
        plt.savefig(fp + suffix)

In [None]:
# Read Fashion++
(X_trainf, M_trainf, y_trainf), (X_testf, M_testf, y_testf) = get_and_read_pickle("Fashion++.pkl.gz")
print(X_trainf.shape, M_trainf.shape, y_trainf.shape,
      X_testf.shape,  M_testf.shape,  y_testf.shape)

In [None]:
# Grid of Fashion++ training images, saved under the name "Fashion++".
# NOTE(review): unlike the MNIST cells, M_trainf is not passed as Ms here;
# confirm the overlay is intentionally omitted for Fashion++.
plot36(X_trainf, y_trainf, name="Fashion++", fp="Fashion++")

In [None]:
# Read MNIST+4
(X_train4, M_train4, y_train4), (X_test4, M_test4, y_test4) = get_and_read_pickle("MNIST+4.pkl.gz")
print(X_train4.shape, M_train4.shape, y_train4.shape,
      X_test4.shape,  M_test4.shape,  y_test4.shape)

In [None]:
# Grid of MNIST+4 training images with M_train4 overlaid, saved as "MNIST+4".
plot36(X_train4, y_train4, name="MNIST+4", fp="MNIST+4", Ms=M_train4)

In [None]:
# Read MNIST+9
(X_train9, M_train9, y_train9), (X_test9, M_test9, y_test9) = get_and_read_pickle("MNIST+9.pkl.gz")
print(X_train9.shape, M_train9.shape, y_train9.shape,
      X_test9.shape,  M_test9.shape,  y_test9.shape)

In [None]:
# Grid of MNIST+9 training images with M_train9 overlaid, saved as "MNIST+9".
plot36(X_train9, y_train9, name="MNIST+9", fp="MNIST+9", Ms=M_train9)

In [None]:
# Read MNIST+Inf
(X_trainInf, M_trainInf, y_trainInf), (X_testInf, M_testInf, y_testInf) = get_and_read_pickle("MNIST+Inf.pkl.gz")
print(X_trainInf.shape, M_trainInf.shape, y_trainInf.shape,
      X_testInf.shape,  M_testInf.shape,  y_testInf.shape)

In [None]:
# Grid of MNIST+Inf training images with M_trainInf overlaid, saved as "MNIST+Inf".
plot36(X_trainInf, y_trainInf, name="MNIST+Inf", fp="MNIST+Inf", Ms=M_trainInf)