# MNIST-minus-minus: Make figures for punky datasets

A handwritten-digit reading task, now with more chaos!

## Authors
- **David W Hogg** (NYU) (Flatiron)
- **Soledad Villar** (JHU)

## To-Do / Bugs:
- What?

## Notes
- null

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
import gzip
import os
import matplotlib as mpl

In [None]:
# Shared settings used by every plot36 call below.
suffix = ".png"  # file extension appended to every saved figure path
figkwargs = dict(figsize=(12, 12.8), layout="tight")  # passed to plt.figure
# Render text, tick marks and axes edges in red.
mpl.rcParams.update({
    'text.color': 'r',
    'xtick.color': 'r',
    'ytick.color': 'r',
    'axes.edgecolor': 'r',
})

In [None]:
# Remote directory the *.pkl.gz dataset files are fetched from; the dataset
# filename is appended directly, so this must end with "/".
baseurl = "https://cosmo.nyu.edu/hogg/research/2023/04/17/"

In [None]:
def get_and_read_pickle(filename, clobber=False):
    """Download ``baseurl + filename`` if needed, then unpickle it from gzip.

    Parameters
    ----------
    filename : str
        Name of the file on the server (appended to ``baseurl``) and the
        local path it is cached at.
    clobber : bool
        If True, re-download even when ``filename`` already exists locally.

    Returns
    -------
    object
        The unpickled contents of the gzipped file.

    Raises
    ------
    subprocess.CalledProcessError
        If ``wget`` exits with a non-zero status (the cached file, if any,
        is left untouched).
    FileNotFoundError
        If ``wget`` is not installed, or no local file exists to read.
    """
    import subprocess

    if clobber or not os.path.isfile(filename):
        # Download to a temporary name and move it into place only on
        # success, so a failed or interrupted download never leaves a
        # truncated file that later calls would treat as a valid cache.
        # Using "-O" also makes clobber work: plain wget saves to
        # "<name>.1" when the target already exists.
        partial = filename + ".part"
        try:
            # NOTE(review): --no-check-certificate (kept from the original)
            # disables TLS verification, and the payload is unpickled below,
            # which can execute arbitrary code -- only use a trusted server.
            subprocess.run(
                ["wget", "--no-check-certificate", "-O", partial,
                 baseurl + filename],
                check=True,
            )
            os.replace(partial, filename)
        finally:
            if os.path.exists(partial):
                os.remove(partial)
    with gzip.open(filename, 'rb') as file:
        return pickle.load(file)

In [None]:
def plot36(Xs, ys, name, Ms=None, fp=None, vmin=0, vmax=255):
    """Show up to the first 36 images in a 6x6 grid, titled with their class.

    Parameters
    ----------
    Xs : array-like
        Stack of images; ``Xs[i]`` is passed to ``imshow``. Assumed to be
        (N, H, W) -- TODO confirm for every dataset.
    ys : sequence
        Class label per image, shown in each panel title.
    name : str
        Dataset name, shown in each panel title.
    Ms : array-like, optional
        If given, ``Ms[i]`` is indexed as a 2x2 array; its row 1 is drawn as
        a solid red segment and its row 0 as a dotted red segment, both
        starting at the image centre. Each row is read as (y, x) components.
        What these arrays mean is not visible here.
    fp : str, optional
        If given, the figure is saved to ``fp + suffix``.
    vmin, vmax : float
        Grey-scale limits passed to ``imshow``.
    """
    ngrid = 6
    # Plot at most ngrid**2 panels; fewer when the dataset is smaller,
    # instead of failing with an IndexError.
    nshow = min(ngrid * ngrid, len(Xs))
    # Image centre in imshow's pixel coordinates (pixel k is centred on k),
    # computed per axis so non-square images are handled too.
    yc = Xs.shape[1] / 2.0 - 0.5
    xc = Xs.shape[2] / 2.0 - 0.5
    dd = 20.  # MAGIC: length scale (pixels) for the Ms line segments
    plt.figure(**figkwargs)
    for i in range(nshow):
        ax = plt.subplot(ngrid, ngrid, i + 1)
        ax.imshow(Xs[i], cmap='gray_r', interpolation='none',
                  vmin=vmin, vmax=vmax)
        # Remember the image limits so the overlay cannot rescale the panel.
        xlim, ylim = ax.get_xlim(), ax.get_ylim()
        if Ms is not None:
            ax.plot([xc, xc + dd * Ms[i, 1, 1]],
                    [yc, yc + dd * Ms[i, 1, 0]], "r-", lw=1)
            ax.plot([xc, xc + dd * Ms[i, 0, 1]],
                    [yc, yc + dd * Ms[i, 0, 0]], "r:", lw=2)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.set_title(f"{name} class {ys[i]}")
        # Bare panels: removing the ticks also removes their labels.
        ax.set_xticks([])
        ax.set_yticks([])

    if fp is not None:
        plt.savefig(fp + suffix)

In [None]:
# Read SixtyNine++: a (train, test) pair, each unpacked as an (X, M, y) triple.
(
    (X_train69, M_train69, y_train69),
    (X_test69, M_test69, y_test69),
) = get_and_read_pickle("SixtyNine++.pkl.gz")

In [None]:
# First training images, with their M arrays overlaid; saved as "SixtyNine++" + suffix.
plot36(
    X_train69,
    y_train69,
    "SixtyNine++",
    Ms=M_train69,
    fp="SixtyNine++",
)

In [None]:
# Read LowRes++: a (train, test) pair, each unpacked as an (X, M, y) triple.
(
    (X_trainLow, M_trainLow, y_trainLow),
    (X_testLow, M_testLow, y_testLow),
) = get_and_read_pickle("LowRes++.pkl.gz")

In [None]:
# First training images (no M overlay); saved as "LowRes++" + suffix.
plot36(
    X_trainLow,
    y_trainLow,
    "LowRes++",
    fp="LowRes++",
)

In [None]:
# Read CutOut++: a (train, test) pair, each unpacked as an (X, M, y) triple.
(
    (X_trainCut, M_trainCut, y_trainCut),
    (X_testCut, M_testCut, y_testCut),
) = get_and_read_pickle("CutOut++.pkl.gz")

In [None]:
# First training images (no M overlay); saved as "CutOut++" + suffix.
plot36(
    X_trainCut,
    y_trainCut,
    "CutOut++",
    fp="CutOut++",
)

In [None]:
# Read Projections++: a (train, test) pair, each unpacked as an (X, M, y) triple.
(
    (X_trainProj, M_trainProj, y_trainProj),
    (X_testProj, M_testProj, y_testProj),
) = get_and_read_pickle("Projections++.pkl.gz")

In [None]:
# First training images (no M overlay); saved as "Projections++" + suffix.
plot36(
    X_trainProj,
    y_trainProj,
    "Projections++",
    fp="Projections++",
)

In [None]:
# Read Crops++: a (train, test) pair, each unpacked as an (X, M, y) triple.
(
    (X_trainCrop, M_trainCrop, y_trainCrop),
    (X_testCrop, M_testCrop, y_testCrop),
) = get_and_read_pickle("Crops++.pkl.gz")

In [None]:
# First training images (no M overlay); saved as "Crops++" + suffix.
plot36(
    X_trainCrop,
    y_trainCrop,
    "Crops++",
    fp="Crops++",
)