In [6]:
import argparse
import gzip
import pathlib
import struct

import numpy as np
import pandas as pd
import requests
from PIL import Image
from tqdm import tqdm_notebook as tqdm

In [7]:
def donwload(urls, path, timeout=60):
    """Download each URL into ``path`` unless a file with that name already exists.

    The (misspelled) name is kept because ``main`` calls it under this name.

    Parameters
    ----------
    urls : iterable of str
        File URLs; the last path component becomes the local file name.
    path : pathlib.Path
        Destination directory, created (with parents) if missing.
    timeout : float, optional
        Seconds passed to ``requests.get`` so a stalled server cannot hang
        the notebook indefinitely.

    Raises
    ------
    requests.HTTPError
        If a server answers with an error status, instead of silently
        leaving the file missing for ``load`` to trip over later.
    """
    path.mkdir(parents=True, exist_ok=True)
    for url in urls:
        filepath = path / pathlib.Path(url).name
        if filepath.exists():
            # Already downloaded on a previous run.
            continue
        res = requests.get(url, timeout=timeout)
        res.raise_for_status()
        # Write under a temporary name and rename afterwards, so an interrupted
        # write never leaves a truncated file that would look cached next run.
        partpath = filepath.with_name(filepath.name + '.part')
        partpath.write_bytes(res.content)
        partpath.replace(filepath)


def _read_exact(f, size, name):
    """Read exactly ``size`` bytes from ``f`` or raise RuntimeError naming ``name``."""
    data = f.read(size)
    if len(data) != size:
        raise RuntimeError(
            '{}: expected {} bytes, got {} (truncated file?)'.format(
                name, size, len(data)))
    return data


def load(paths):
    """Read a pair of gzipped MNIST IDX files into NumPy arrays.

    Parameters
    ----------
    paths : sequence of two path-like
        ``(images_path, labels_path)``: an ``idx3-ubyte.gz`` image file and
        the matching ``idx1-ubyte.gz`` label file.

    Returns
    -------
    images : np.ndarray, dtype uint8, shape (N, rows * cols)
        One flattened image per row, i.e. (N, 784) for 28x28 MNIST digits.
    labels : np.ndarray, dtype uint8, shape (N,)

    Raises
    ------
    RuntimeError
        If a file has the wrong IDX magic number, the two files disagree on
        the item count, or either file is truncated.
    """
    x_path, y_path = paths
    with gzip.open(x_path) as fx, gzip.open(y_path) as fy:
        # IDX headers are big-endian uint32: magic, item count and, for the
        # 3-D image file only, the number of rows and columns.
        y_magic, n_labels = struct.unpack('>II', _read_exact(fy, 8, y_path))
        x_magic, n_images, rows, cols = struct.unpack(
            '>IIII', _read_exact(fx, 16, x_path))
        if y_magic != 2049:  # 0x00000801: unsigned byte, 1 dimension
            raise RuntimeError('{} is not an IDX label file'.format(y_path))
        if x_magic != 2051:  # 0x00000803: unsigned byte, 3 dimensions
            raise RuntimeError('{} is not an IDX image file'.format(x_path))
        if n_labels != n_images:
            raise RuntimeError('wrong pair of MNIST images and labels')

        n = n_labels
        pixels = rows * cols
        # Read each payload in one call; .copy() makes the result writable
        # (np.frombuffer over bytes is read-only).
        labels = np.frombuffer(
            _read_exact(fy, n, y_path), dtype=np.uint8).copy()
        images = np.frombuffer(
            _read_exact(fx, n * pixels, x_path), dtype=np.uint8
        ).reshape(n, pixels).copy()
    return images, labels


def make_images(path, images, labels):
    """Save each flattened 28x28 image as ``<label>_<index>.jpg`` in ``path``.

    ``path`` is created (with parents) if needed. Images and labels are
    paired positionally; iteration stops at the shorter of the two. The file
    name scheme matches the names listed by ``make_labellist``.
    """
    path.mkdir(parents=True, exist_ok=True)
    for idx, (pixels, digit) in enumerate(zip(images, labels)):
        target = path / '{}_{}.jpg'.format(digit, idx)
        Image.fromarray(pixels.reshape(28, 28)).save(target)


def make_labellist(path, kind, labels):
    """Write ``<path>/<kind>.csv`` mapping each image file name to its label.

    Each row is ``<label>_<index>.jpg,<label>`` with no header and no index
    column. ``path`` is created (with parents) if needed. ``labels`` must
    support ``.tolist()`` (e.g. a NumPy array).
    """
    path.mkdir(parents=True, exist_ok=True)
    targets = labels.tolist()
    names = ['{}_{}.jpg'.format(target, idx) for idx, target in enumerate(targets)]
    table = pd.DataFrame({'name': names, 'target': targets})
    csv_path = path / '{}.csv'.format(kind)
    table.to_csv(csv_path, index=False, header=False)


def main(path='./src/data', out="jpg",
         baseurl='http://yann.lecun.com/exdb/mnist'):
    """Download MNIST and export the train and test splits.

    Parameters
    ----------
    path : str or path-like
        Root data directory; the raw ``.gz`` archives go to ``<path>/raw``.
    out : {'jpg', 'npz'}
        ``'jpg'`` writes one JPEG per image to ``<path>/mnist/images/<kind>``
        and a CSV label list to ``<path>/mnist/labels/<kind>.csv``.
        ``'npz'`` writes ``<path>/mnist/npz/<kind>.npz`` holding arrays
        ``x`` (images) and ``y`` (labels).
    baseurl : str
        URL of the directory hosting the four MNIST ``.gz`` files; override
        it to use a mirror if the default host is unavailable.

    Raises
    ------
    ValueError
        If ``out`` is not a supported format. This is checked before any
        download starts, so a typo no longer silently produces NPZ output.
    """
    if out not in ('jpg', 'npz'):
        raise ValueError("out must be 'jpg' or 'npz', got {!r}".format(out))

    root = pathlib.Path(path)
    raw_dir = root / 'raw'
    mnist_dir = root / 'mnist'
    baseurl = baseurl.rstrip('/')

    def pipeline(kind):
        # The test split is published under the 't10k' file prefix.
        prefix = 't10k' if kind == 'test' else kind
        names = [
            '{}-images-idx3-ubyte.gz'.format(prefix),
            '{}-labels-idx1-ubyte.gz'.format(prefix),
        ]
        donwload(['{}/{}'.format(baseurl, name) for name in names], raw_dir)
        images, labels = load([raw_dir / name for name in names])

        if out == 'jpg':
            make_images(mnist_dir / 'images' / kind, images, labels)
            make_labellist(mnist_dir / 'labels', kind, labels)
        else:
            npz_dir = mnist_dir / 'npz'
            npz_dir.mkdir(parents=True, exist_ok=True)
            np.savez_compressed(
                npz_dir / '{}.npz'.format(kind), x=images, y=labels)

    for kind in ('train', 'test'):
        print('Processing {} data ...'.format(kind))
        pipeline(kind)

In [8]:
# Download the raw MNIST archives into ./src/data/raw, then write JPEG images
# and CSV label lists under ./src/data/mnist (main()'s default arguments).
main()

Processing train data ...



Processing test data ...



