<a href="https://colab.research.google.com/github/matthew-mcateer/noise-weight-theft/blob/master/qmnist_numpy_prep.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Ask the inline matplotlib backend for high-DPI ("retina") figure output.
# NOTE(review): no figure is drawn in the cells visible here — confirm this is still needed.
%config InlineBackend.figure_format = 'retina'

# Fetch the QMNIST archives from the facebookresearch/qmnist GitHub repository.
# `-q` silences wget's progress output.
!wget -q https://github.com/facebookresearch/qmnist/raw/master/qmnist-test-images-idx3-ubyte.gz
!wget -q https://github.com/facebookresearch/qmnist/raw/master/qmnist-test-labels-idx1-ubyte.gz
!wget -q https://github.com/facebookresearch/qmnist/raw/master/qmnist-test-labels-idx2-int.gz
!wget -q https://github.com/facebookresearch/qmnist/raw/master/qmnist-test-labels.tsv.gz
!wget -q https://github.com/facebookresearch/qmnist/raw/master/qmnist-train-images-idx3-ubyte.gz
!wget -q https://github.com/facebookresearch/qmnist/raw/master/qmnist-train-labels-idx2-int.gz
!wget -q https://github.com/facebookresearch/qmnist/raw/master/qmnist-train-labels.tsv.gz
# NOTE(review): the three xnist-*.xz archives below are downloaded but are not
# decompressed in this cell and are not read by the later cells shown here.
!wget -q https://github.com/facebookresearch/qmnist/raw/master/xnist-images-idx3-ubyte.xz
!wget -q https://github.com/facebookresearch/qmnist/raw/master/xnist-labels-idx2-int.xz
!wget -q https://github.com/facebookresearch/qmnist/raw/master/xnist-labels.tsv.xz

# Decompress each .gz archive in place; gunzip replaces `name.gz` with `name`,
# which is the uncompressed filename the loading cell opens.
!gunzip qmnist-test-images-idx3-ubyte.gz
!gunzip qmnist-test-labels-idx1-ubyte.gz
!gunzip qmnist-test-labels-idx2-int.gz
!gunzip qmnist-test-labels.tsv.gz
!gunzip qmnist-train-images-idx3-ubyte.gz
!gunzip qmnist-train-labels-idx2-int.gz
!gunzip qmnist-train-labels.tsv.gz


gzip: xnist-images-idx3-ubyte.xz: unknown suffix -- ignored
gzip: xnist-labels-idx2-int.xz: unknown suffix -- ignored
gzip: xnist-labels.tsv.xz: unknown suffix -- ignored


In [0]:
import codecs
import numpy as np
import torch

def get_int(b):
    """Interpret a bytes object as an unsigned big-endian integer.

    Args:
        b: bytes-like value, e.g. a 4-byte IDX header field.

    Returns:
        The integer whose hexadecimal spelling is the hex encoding of ``b``.

    Raises:
        ValueError: if ``b`` is empty (there are no hex digits to parse).
    """
    hex_digits = codecs.encode(b, 'hex')  # e.g. b'\x08\x03' -> b'0803'
    return int(hex_digits, 16)

def open_maybe_compressed_file(path):
    """Open ``path`` for binary reading, decompressing transparently if needed.

    The compression scheme is chosen purely from the filename suffix:
    ``.gz`` is opened with ``gzip``, ``.xz`` with ``lzma``, and anything
    else with the built-in ``open``.

    Args:
        path: filesystem path as a ``str``.

    Returns:
        A binary file object usable as a context manager.

    Raises:
        OSError: if the file cannot be opened.
    """
    if path.endswith('.gz'):
        # Imported here because the notebook's import cell does not bring
        # gzip into scope; without this the branch fails with NameError.
        import gzip
        return gzip.open(path, 'rb')
    if path.endswith('.xz'):
        # Same reasoning as above for lzma.
        import lzma
        return lzma.open(path, 'rb')
    return open(path, 'rb')
    
def read_idx2_int(path):
    """Load a 2-D big-endian int32 IDX file as an int64 NumPy array.

    The file starts with a 12-byte header: a 4-byte magic number that must
    equal ``12*256 + 2``, followed by the row count and the column count,
    each a 4-byte big-endian integer. The remaining bytes are the values.

    Args:
        path: file path; ``.gz``/``.xz`` suffixes are handled by
            ``open_maybe_compressed_file``.

    Returns:
        ``numpy.ndarray`` of dtype int64 with shape ``(rows, cols)``.

    Raises:
        AssertionError: if the magic number does not match.
    """
    with open_maybe_compressed_file(path) as f:
        raw = f.read()

    magic = get_int(raw[:4])
    assert magic == 12*256 + 2
    n_rows = get_int(raw[4:8])
    n_cols = get_int(raw[8:12])

    # Payload is stored big-endian; convert to native-order int32 first.
    big_endian = np.frombuffer(raw, dtype=np.dtype('>i4'), offset=12)
    native = big_endian.astype('i4')
    # Shape and widen to int64 via torch, then hand back a NumPy array.
    tensor = torch.from_numpy(native).view(n_rows, n_cols).long()
    return tensor.numpy()

def read_idx3_ubyte(path):
    """Load a 3-D unsigned-byte IDX file as a uint8 NumPy array.

    The file starts with a 16-byte header: a 4-byte magic number that must
    equal ``8 * 256 + 3``, followed by the item count, row count and column
    count, each a 4-byte big-endian integer. The remaining bytes are the
    values, one byte each.

    Args:
        path: file path; ``.gz``/``.xz`` suffixes are handled by
            ``open_maybe_compressed_file``.

    Returns:
        ``numpy.ndarray`` of dtype uint8 with shape
        ``(length, num_rows, num_cols)``. The array owns its memory and is
        writable.

    Raises:
        AssertionError: if the magic number does not match.
        ValueError: if the payload size does not match the header dimensions.
    """
    with open_maybe_compressed_file(path) as f:
        data = f.read()

    assert get_int(data[:4]) == 8 * 256 + 3
    length = get_int(data[4:8])
    num_rows = get_int(data[8:12])
    num_cols = get_int(data[12:16])

    # frombuffer over an immutable bytes object yields a READ-ONLY view.
    # Routing it through torch.from_numpy(...).numpy() would emit a warning
    # and return an array flagged writable that still aliases the immutable
    # buffer, so we reshape in NumPy and copy into an owned, writable array.
    parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
    return parsed.reshape(length, num_rows, num_cols).copy()

In [0]:
# Image tensors for the two splits.
x_train = read_idx3_ubyte('qmnist-train-images-idx3-ubyte')
x_test = read_idx3_ubyte('qmnist-test-images-idx3-ubyte')

# Keep only column 0 of each label matrix and cast it to float32.
# NOTE(review): column 0 is presumably the digit class — confirm against the QMNIST docs.
y_train = read_idx2_int('qmnist-train-labels-idx2-int')[:, 0].astype('float32')
y_test = read_idx2_int('qmnist-test-labels-idx2-int')[:, 0].astype('float32')

# Report shape, container type and scalar element type of every array.
print("x_train: {}, {}, {}".format(x_train.shape, type(x_train), type(x_train[0, 0, 0])))
print("y_train: {}, {}, {}".format(y_train.shape, type(y_train), type(y_train[0])))
print("x_test:  {}, {}, {}".format(x_test.shape, type(x_test), type(x_test[0, 0, 0])))
print("y_test:  {}, {}, {}".format(y_test.shape, type(y_test), type(y_test[0])))

x_train: (60000, 28, 28), <class 'numpy.ndarray'>, <class 'numpy.uint8'>
y_train: (60000,), <class 'numpy.ndarray'>, <class 'numpy.float32'>
x_test:  (60000, 28, 28), <class 'numpy.ndarray'>, <class 'numpy.uint8'>
y_test:  (60000,), <class 'numpy.ndarray'>, <class 'numpy.float32'>


In [0]:
# Persist each array in NumPy's binary format; np.save appends the ".npy"
# suffix to the given name, so these produce qmnist_x_train.npy, etc.
np.save(file='qmnist_x_train', arr=x_train)
np.save(file='qmnist_y_train', arr=y_train)
np.save(file='qmnist_x_test', arr=x_test)
np.save(file='qmnist_y_test', arr=y_test)
