# MNIST data analyses

_Copyright (C) 2023-, Joseph T. Lizier._
_Distributed under GNU General Public License v3_

This notebook template sets up reading in the data files for the MNIST analysis.

The first code block just starts up the JVM. You may need to change the paths below (the AutoAnalyser can show you the right ones):

In [None]:
from jpype import *
import numpy
import sys
# Our python data file readers are a bit of a hack, python users will do better on this:
sys.path.append("../../../demos/python")
import readIntsFile

if (not isJVMStarted()):
    # Add JIDT jar library to the path
    jarLocation = "../../../infodynamics.jar"
    # Start the JVM (add the "-Xmx" option with say 1024M if you get crashes due to not enough memory space)
    startJVM(getDefaultJVMPath(), "-ea", "-Djava.class.path=" + jarLocation, convertStrings=True)
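
If the JVM runs out of memory, the comment above suggests adding an `-Xmx` option. As a minimal sketch of how that could be wired in, the helper below builds the option strings for `startJVM` without starting anything. The name `build_jvm_args` and its `max_heap` parameter are hypothetical and are not part of JPype or JIDT.

```python
def build_jvm_args(jar_location, max_heap=None):
    """Return the JVM option strings to pass to startJVM (hypothetical helper)."""
    args = ["-ea"]  # enable Java assertions, as in the cell above
    if max_heap is not None:
        # e.g. "-Xmx1024M" sets the maximum Java heap size to 1024 MB
        args.append("-Xmx" + max_heap)
    args.append("-Djava.class.path=" + jar_location)
    return args

jvm_args = build_jvm_args("../../../infodynamics.jar", max_heap="1024M")
print(jvm_args)
```

The resulting list could then be unpacked into the existing call, for example `startJVM(getDefaultJVMPath(), *jvm_args, convertStrings=True)`.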

Next, load in the data and check that we can pull out a sample digit.

We give code for two options to load the data:
* using the `torchvision` package (which depends on `torch`) to download and load the data automatically (_default_), or
* via a Matlab pre-processed `.mat` file.

Either way, the data is loaded into the following numpy arrays:
* `classes` -- a 1D array holding the class (digit 0..9) of each sample;
* `pixels1D` -- binarised pixel data for each sample, with dimensions (numTrials, imageDimension * imageDimension), i.e. one-dimensional data for each sample (this is the usual JIDT multivariate data format, with each row being a sample and each column a variable);
* `pixels2D` -- binarised pixel data for each sample, with dimensions (numTrials, imageDimension, imageDimension), i.e. two-dimensional image data for each sample.

You can work with whichever of `pixels1D` or `pixels2D` you prefer.
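
To make the relationship between the two layouts concrete, here is a small sketch on synthetic data. The random images are a hypothetical stand-in for MNIST and the threshold value is only illustrative. It shows that pixel (row, col) of an image sits in column `row * imageDimension + col` of the 1D layout, and that reshaping recovers the 2D layout exactly.

```python
import numpy

# Hypothetical stand-in for the raw MNIST pixels: 5 random 28x28 greyscale images
rng = numpy.random.default_rng(0)
numTrials, imageDimension = 5, 28
rawPixels2D = rng.integers(0, 256, size=(numTrials, imageDimension, imageDimension))

# Flatten each image row by row into a single row of 784 variables
rawPixels1D = numpy.reshape(rawPixels2D, (numTrials, imageDimension * imageDimension))

# Binarise both layouts with the same (illustrative) threshold
threshold = 15
pixels2D = rawPixels2D > threshold
pixels1D = rawPixels1D > threshold

# Pixel (row, col) of an image is column row * imageDimension + col in the 1D layout
row, col = 3, 7
same_pixel = pixels2D[0, row, col] == pixels1D[0, row * imageDimension + col]

# Reshaping the 1D layout recovers the 2D layout exactly
round_trip = numpy.array_equal(numpy.reshape(pixels1D, pixels2D.shape), pixels2D)
print(pixels1D.shape, pixels2D.shape, same_pixel, round_trip)
```

Because both arrays hold the same values, the choice between them only affects indexing convenience: the 1D layout suits calculators that expect one column per variable, while the 2D layout is easier when reasoning about neighbouring pixels.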

In [None]:
# Load/prepare the data:

# Choose whether to import via torch or via the distributed Matlab format:
downloadViaTorch = True

if (downloadViaTorch):
    # Option 1: automatically download (into local folder 'data') and import the data using torchvision:
    #  (thanks to Isabelle De Backer for pointing this option out)
    import torchvision
    trainData = torchvision.datasets.MNIST('./data', download=True, train=True)
    testData = torchvision.datasets.MNIST('./data', download=True, train=False)
    # Stack the train and test images along the sample axis:
    rawPixels2D = numpy.vstack((trainData.data, testData.data))
    imageDimension = rawPixels2D.shape[1];
    numPixels = imageDimension * imageDimension;
    numTrials = rawPixels2D.shape[0];
    rawPixels1D = numpy.reshape(rawPixels2D, (numTrials, imageDimension * imageDimension))
    classes = numpy.concatenate((trainData.targets, testData.targets))
else:
    # Option 2: import the pre-prepared data in Matlab format:
    import scipy.io
    data = scipy.io.loadmat('./trialAndTest-processedData.mat');
    trainAndTestData = numpy.array(data['trialAndTestData'])
    classes = trainAndTestData[:,0] - 1; # Need classes to start from 0 for JIDT (and it's sensible to match digits here also)
    rawPixels1D = trainAndTestData[:,1:];
    numPixels = rawPixels1D.shape[1];
    numTrials = rawPixels1D.shape[0];
    import math
    imageDimension = int(math.sqrt(numPixels)); # Will be 28
    rawPixels2D = numpy.reshape(rawPixels1D, (numTrials, imageDimension, imageDimension));

numClasses = 10; # We know we only have digits 0-9 here
print('Loaded MNIST data with %d samples, classes %d:%d, and %d pixels per sample\n' %\
    (numTrials, min(classes), max(classes), numPixels));

# Binarise the pixel data:
threshold = 15; # From viewing histograms of pixel values this seems reasonable.
pixels2D = rawPixels2D > threshold; # In dimensions (numTrials, imageDimension, imageDimension) 
pixels1D = rawPixels1D > threshold; # In dimensions (numTrials, imageDimension * imageDimension)

# And check that we can plot one of the images ok from either the 1D or 2D arrangements:
import matplotlib.pyplot as plt
plt.figure(figsize=(12,6))
plt.subplot(1,2,1) # left subplot
sampleIndex = 2;
plt.imshow( pixels2D[sampleIndex,:,:] ) # Plotting the 3rd digit as a sample
plt.title('Sample %d from pixels2D (digit=%d)' % (sampleIndex, classes[sampleIndex]));
plt.subplot(1, 2, 2); # right one
plt.imshow( numpy.reshape(pixels1D[sampleIndex,:], (imageDimension, imageDimension) )) # Plotting the 3rd digit as a sample
plt.title('Sample %d from pixels1D (digit=%d)' % (sampleIndex, classes[sampleIndex]));

# Your analysis goes next ...: