In [2]:
#e
from pathlib import Path
from urllib.request import urlretrieve
import gzip
import pickle

import torch
import datasets as hfds

# Root directory, under the current user's home, below which datasets are cached.
DATASETS_CACHE_BASE_PATH = Path.home() / ".cache" / "minai" / "datasets"

In [3]:
#e
class SimpleDataset:
    """Minimal indexable dataset that pairs inputs ``xs`` with targets ``ys``.

    Both containers are stored as given (no copy is made). They must have the
    same length and support ``len()`` and integer indexing.

    Attributes:
        xs: container of inputs.
        ys: container of targets, aligned index-by-index with ``xs``.
        len: number of samples, captured at construction time.
    """

    def __init__(self, xs, ys):
        self.xs = xs
        self.ys = ys
        assert len(xs) == len(ys), \
            f"xs and ys must have the same length, got {len(xs)} and {len(ys)}"
        self.len = len(xs)

    def __len__(self):
        """Return the number of samples."""
        return self.len

    def __getitem__(self, i):
        """Return the ``(x, y)`` pair at position ``i``.

        Only plain ``int`` indices are accepted; slices, bools and other
        integer-like types trip the assertion.
        """
        assert type(i) is int, f"index must be an int, got {type(i).__qualname__}"
        return self.xs[i], self.ys[i]

    def __repr__(self):
        """Summarise the dataset as its length plus the element types of xs/ys."""
        if self.len == 0:
            # There is no first sample to inspect; indexing [0] would raise
            # IndexError and make an empty dataset impossible to print.
            x_type = y_type = "n/a"
        else:
            x_type = type(self.xs[0]).__qualname__
            y_type = type(self.ys[0]).__qualname__
        return f"SimpleDataset(len={self.len}, xs={x_type}, ys={y_type})"

In [4]:
#e
def load_mnist():
    """Download (on first use) and load MNIST as three ``SimpleDataset`` splits.

    The gzipped pickle is cached under ``DATASETS_CACHE_BASE_PATH / "MNIST"``.
    The download is skipped when the cached file already exists.

    Returns:
        dict: keys ``"train"``, ``"valid"`` and ``"test"``, each mapping to a
        ``SimpleDataset`` whose ``xs`` and ``ys`` are ``torch`` tensors.
    """
    MNIST_URL = "https://github.com/mnielsen/neural-networks-and-deep-learning/raw/master/data/mnist.pkl.gz"

    path_data = DATASETS_CACHE_BASE_PATH / "MNIST"
    path_data.mkdir(exist_ok=True, parents=True)
    # The payload is gzip, not zip; the file name is kept as-is so caches
    # written by earlier runs are still found.
    path_zip = path_data / "mnist.zip"

    if not path_zip.exists():
        print(f"Downloading file to {path_zip}")
        # Download to a sibling temporary file and rename it into place only on
        # success. Otherwise an interrupted download would leave a truncated
        # file that every later call mistakes for a valid cache entry.
        path_part = path_zip.with_name(path_zip.name + ".part")
        try:
            urlretrieve(MNIST_URL, path_part)
            path_part.replace(path_zip)
        finally:
            # Only present here if the download or the rename failed.
            if path_part.exists():
                path_part.unlink()

    with gzip.open(path_zip) as f:
        # SECURITY: unpickling can execute arbitrary code; this trusts the
        # contents served at MNIST_URL (and whatever sits in the cache file).
        # encoding="latin-1" is presumably needed because the archive was
        # pickled under Python 2 -- TODO confirm.
        splits = pickle.load(f, encoding="latin-1")

    (x_train, y_train), (x_val, y_val), (x_test, y_test) = splits
    x_train, y_train, x_val, y_val, x_test, y_test = map(
        torch.tensor, (x_train, y_train, x_val, y_val, x_test, y_test)
    )

    dsd = dict(
        train=SimpleDataset(x_train, y_train),
        valid=SimpleDataset(x_val, y_val),
        test=SimpleDataset(x_test, y_test),
    )

    return dsd

In [5]:
# Load the MNIST splits (downloading on first use) and display their summary.
mnist = load_mnist()
mnist

{'train': SimpleDataset(len=50000, xs=Tensor, ys=Tensor),
 'valid': SimpleDataset(len=10000, xs=Tensor, ys=Tensor),
 'test': SimpleDataset(len=10000, xs=Tensor, ys=Tensor)}

In [6]:
#e
class HF_DATASETS:
    """Namespace of dataset identifier strings intended to be passed to ``hf_load``."""
    # Identifier for the Fashion-MNIST dataset.
    FASHION_MNIST = "fashion_mnist"
    # Identifier for the Tiny ImageNet dataset (namespaced as "<owner>/<name>").
    TINY_IMAGENET = "zh-plus/tiny-imagenet"

# Short alias for the `datasets` library loader (imported above as `hfds`).
hf_load = hfds.load_dataset

In [7]:
# Fetch Fashion-MNIST through the `datasets` loader alias defined above.
fashion_mnist = hf_load(HF_DATASETS.FASHION_MNIST)

Found cached dataset fashion_mnist (/home/nblzv/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/0a671f063342996f19779d38c0ab4abef9c64f757b35af8134b331c294d7ba48)


  0%|          | 0/2 [00:00<?, ?it/s]

In [8]:
# Display the loaded object to inspect its splits, features and row counts.
fashion_mnist

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 60000
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 10000
    })
})

In [9]:
# Run the project's export step. Judging by the output below it writes notebook
# cells out to the `minai` package sources -- presumably the cells tagged `#e`;
# verify against z_export.
import z_export
z_export.export()

Processing minai_nbs/datasets.ipynb -> minai/minai/datasets.py  |  4 cells exported, took 0.001s 
Processing minai_nbs/sampler.ipynb -> minai/minai/sampler.py  |  same contents, skipping, took 0.000s
Processing minai_nbs/setup+template.py -> minai/setup.py  |  same contents, skipping, took 0.000s
Processing minai_nbs/__init__+template.py -> minai/minai/__init__.py  |  same contents, skipping, took 0.000s
Processing minai_nbs/plot.ipynb -> minai/minai/plot.py  |  same contents, skipping, took 0.000s
Processing minai_nbs/mintils.py -> minai/minai/mintils.py  |  same contents, skipping, took 0.000s
Processing minai_nbs/data.ipynb -> minai/minai/data.py  |  same contents, skipping, took 0.001s

All done... took 0.003s
  lib_name: minai
  author: nblzv
  version: 0.1.1
