# Make Dataset MNIST

In [1]:
import pathlib
import struct

import numpy as np
import matplotlib.pyplot as plt

import torch
from torchvision.datasets import MNIST

In [2]:
# Project data directory: the `data` folder two levels above the current
# working directory, plus the sub-folder used for the external MNIST files.
data_storage_path = pathlib.Path.cwd().parents[1].joinpath('data')
data_storage_path_mnist = data_storage_path.joinpath('external', 'mnist')

## Download

In [3]:
# Download the MNIST training split into the external-data folder.
# No transform is applied; the cells below read the raw IDX files directly.
dataset = MNIST(
    root=str(data_storage_path_mnist),
    train=True,
    transform=None,
    download=True,
)

## Load

In [11]:
def readimage(fz):
    """Read an IDX3 image file (the raw MNIST image format) into an array.

    The file consists of a 16-byte big-endian header (magic number 2051,
    number of images, rows, columns) followed by one unsigned byte per pixel.

    Parameters
    ----------
    fz : str or path-like
        Path of the uncompressed IDX3 image file.

    Returns
    -------
    numpy.ndarray
        ``float64`` array of shape ``(num, nrow, ncol)`` holding the raw
        pixel values (0-255).

    Raises
    ------
    ValueError
        If the header is incomplete, the magic number is not 2051, a header
        dimension is negative, or the file contains fewer pixel bytes than
        the header announces.
    """
    with open(fz, 'rb') as f:
        header = f.read(16)
        if len(header) != 16:
            raise ValueError('{!r}: incomplete IDX3 header'.format(fz))
        mn, num, nrow, ncol = struct.unpack('>4i', header)
        # Raise explicitly rather than ``assert``: assertions are stripped
        # under ``python -O`` and a wrong file would then be parsed silently.
        if mn != 2051:
            raise ValueError(
                '{!r}: bad magic number {} (expected 2051)'.format(fz, mn))
        if min(num, nrow, ncol) < 0:
            raise ValueError('{!r}: negative dimension in header'.format(fz))
        nbytes = num * nrow * ncol
        raw = f.read(nbytes)
    if len(raw) != nbytes:
        raise ValueError(
            '{!r}: expected {} pixel bytes, found {}'.format(
                fz, nbytes, len(raw)))
    # Decode every pixel in a single vectorised pass instead of unpacking
    # image by image in Python; the result is returned as float64.
    im = np.frombuffer(raw, dtype=np.uint8).reshape((num, nrow, ncol))
    return im.astype(np.float64)

In [20]:
def readlabel(fz):
    """Read an IDX1 label file (the raw MNIST label format) into an array.

    The file consists of an 8-byte big-endian header (magic number 2049,
    number of labels) followed by one unsigned byte per label.

    Parameters
    ----------
    fz : str or path-like
        Path of the uncompressed IDX1 label file.

    Returns
    -------
    numpy.ndarray
        One-dimensional integer array (``dtype=int``) of length ``num``.

    Raises
    ------
    ValueError
        If the header is incomplete, the magic number is not 2049, or the
        number of label bytes differs from the count given in the header.
    """
    with open(fz, 'rb') as f:
        header = f.read(8)
        if len(header) != 8:
            raise ValueError('{!r}: incomplete IDX1 header'.format(fz))
        mn, num = struct.unpack('>2i', header)
        # Raise explicitly rather than ``assert``: assertions are stripped
        # under ``python -O`` and a wrong file would then be parsed silently.
        if mn != 2049:
            raise ValueError(
                '{!r}: bad magic number {} (expected 2049)'.format(fz, mn))
        raw = f.read()
    if len(raw) != num:
        raise ValueError(
            '{!r}: header announces {} labels, file holds {}'.format(
                fz, num, len(raw)))
    # Each label is a single unsigned byte; widen to the platform int type.
    return np.frombuffer(raw, dtype=np.uint8).astype(int)

In [6]:
# Directory holding the raw IDX files that the loading cells read from
# (presumably the `<root>/MNIST/raw` layout created by the download step).
raw_mnist_path = data_storage_path_mnist.joinpath('MNIST', 'raw')

In [15]:
# File names (relative to the raw-data directory) of the four IDX files:
# images and labels for the training split, then for the test split.
train_images_path, train_labels_path = (
    'train-images-idx3-ubyte',
    'train-labels-idx1-ubyte',
)
test_images_path, test_labels_path = (
    't10k-images-idx3-ubyte',
    't10k-labels-idx1-ubyte',
)

In [16]:
# Parse the training-split image file into an array.
train_images = readimage(str(raw_mnist_path.joinpath(train_images_path)))

In [17]:
train_images.shape

(60000, 28, 28)

In [21]:
# Parse the training-split label file into an array.
train_labels = readlabel(str(raw_mnist_path.joinpath(train_labels_path)))

In [22]:
train_labels.shape

(60000,)

In [23]:
# Parse the test-split image file into an array.
test_images = readimage(str(raw_mnist_path.joinpath(test_images_path)))

In [24]:
test_images.shape

(10000, 28, 28)

In [25]:
# Parse the test-split label file into an array.
test_labels = readlabel(str(raw_mnist_path.joinpath(test_labels_path)))

In [26]:
test_labels.shape

(10000,)

## Extract for Anomaly Detection

In [31]:
# Digit classes used for the anomaly-detection subset: samples labelled
# `label_normal` are the normal class, `label_abnormal` the anomalous one.
label_normal, label_abnormal = 1, 9

In [32]:
# Index tuple of the training samples whose label is the normal class.
idx_train = np.nonzero(train_labels == label_normal)

In [33]:
# Keep only the training images at the positions selected in idx_train.
train_images_ad = train_images[idx_train]

In [34]:
train_images_ad.shape

(6742, 28, 28)

In [50]:
# Persist the normal-class training images. The target folder is created
# first so the save does not fail when `processed/mnist` is missing.
(data_storage_path / 'processed' / 'mnist').mkdir(parents=True, exist_ok=True)
np.save(
    str(data_storage_path / 'processed' / 'mnist' / 'train_images.npy'),
    train_images_ad,
    )

In [35]:
# Labels at the same positions (idx_train) as the selected training images.
train_labels_ad = train_labels[idx_train]

In [36]:
train_labels_ad.shape

(6742,)

In [38]:
train_labels_ad[:10]

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [51]:
# Persist the labels of the selected training samples. The target folder is
# created first so the save does not fail when `processed/mnist` is missing.
(data_storage_path / 'processed' / 'mnist').mkdir(parents=True, exist_ok=True)
np.save(
    str(data_storage_path / 'processed' / 'mnist' / 'train_labels.npy'),
    train_labels_ad,
    )

In [44]:
# Index tuple of the test samples labelled as either the normal or the
# abnormal class.
idx_test = np.nonzero(
    (test_labels == label_normal) | (test_labels == label_abnormal)
)

In [45]:
# Keep only the test images at the positions selected in idx_test.
test_images_ad = test_images[idx_test]

In [46]:
test_images_ad.shape

(2144, 28, 28)

In [52]:
# Persist the selected test images. The target folder is created first so
# the save does not fail when `processed/mnist` is missing.
(data_storage_path / 'processed' / 'mnist').mkdir(parents=True, exist_ok=True)
np.save(
    str(data_storage_path / 'processed' / 'mnist' / 'test_images.npy'),
    test_images_ad,
    )

In [47]:
# Labels at the same positions (idx_test) as the selected test images.
test_labels_ad = test_labels[idx_test]

In [48]:
test_labels_ad.shape

(2144,)

In [49]:
test_labels_ad[:10]

array([1, 1, 9, 9, 9, 1, 9, 9, 1, 1])

In [53]:
# Persist the labels of the selected test samples. The target folder is
# created first so the save does not fail when `processed/mnist` is missing.
(data_storage_path / 'processed' / 'mnist').mkdir(parents=True, exist_ok=True)
np.save(
    str(data_storage_path / 'processed' / 'mnist' / 'test_labels.npy'),
    test_labels_ad,
    )