### Using `TensorDataset` and `DataLoader`
 + Based on the PyTorch tutorial [What is `torch.nn` *really*?](https://pytorch.org/tutorials/beginner/nn_tutorial.html)

In [None]:
%matplotlib inline

from pathlib import Path
import requests
import pickle
import numpy as np
import matplotlib.pyplot as plt

import gzip

import torch
from torch.utils.data import TensorDataset, DataLoader

In [None]:
# Directory where the MNIST pickle is cached, relative to the kernel's working directory.
data_dir = Path('../data')

### Get data

In [None]:
# The original mirror (http://deeplearning.net/data/mnist/) is offline; this is the
# copy that the PyTorch `nn_tutorial` itself downloads.
url = 'https://github.com/pytorch/tutorials/raw/main/_static/'
filename = "mnist.pkl.gz"

path = Path(data_dir, filename)

if not path.exists():
    # Create the data directory on first run instead of failing inside `open`.
    path.parent.mkdir(parents=True, exist_ok=True)
    r = requests.get(url + filename, timeout=60)
    # Fail loudly on HTTP errors so an error page is never cached as the dataset.
    r.raise_for_status()
    with open(path, 'wb') as h:
        h.write(r.content)

# The pickle holds three (inputs, labels) splits; only train and valid are used here.
# NOTE: `pickle.load` can execute arbitrary code — only load files from a trusted source.
with gzip.open(path, "rb") as h:
    (x_train, y_train), (x_valid, y_valid), _ = pickle.load(h, encoding="latin-1")

### Initialize a dataloader

In [None]:
bs = 32  # mini-batch size

# Inputs as float32. Labels as int64 (`torch.long`) class indices: that is the target
# dtype `torch.nn.functional.cross_entropy` expects for class-index targets.
train_ds = TensorDataset(torch.as_tensor(x_train, dtype=torch.float32),
                         torch.as_tensor(y_train, dtype=torch.long))

valid_ds = TensorDataset(torch.as_tensor(x_valid, dtype=torch.float32),
                         torch.as_tensor(y_valid, dtype=torch.long))

# Reshuffle the training set every epoch; keep validation order fixed.
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=bs)

In [None]:
# `DataLoader` provides a convenient way for iterating over `TensorDataset`:
# each step yields one (inputs, labels) mini-batch of at most `bs` samples.
n_train_batches = sum(1 for _xb, _yb in train_dl)
n_valid_batches = sum(1 for _xb, _yb in valid_dl)

# `len(DataLoader)` reports the same count without a full pass over the data.
assert n_train_batches == len(train_dl)
assert n_valid_batches == len(valid_dl)

print('[#batches] train: {}, valid: {}'.format(n_train_batches, n_valid_batches))

### Visualize random samples

In [None]:
def show_random_samples(dataset, rows=5, cols=5, width=8, height=8, image_shape=(28, 28)):
    """Show up to ``rows * cols`` random samples from ``dataset`` in a grid.

    Args:
        dataset: indexable dataset with ``len()`` whose items are ``(x, y)`` pairs,
            e.g. a ``TensorDataset``; each ``x`` must be reshapeable to ``image_shape``.
        rows, cols: grid dimensions.
        width, height: figure size in inches.
        image_shape: shape each flattened sample is reshaped to before display
            (defaults to 28 x 28 for MNIST digits).

    Returns:
        The matplotlib ``Figure``; each panel is titled with the sample's label.
    """
    # `squeeze=False` keeps a 2-D Axes array even for a 1 x 1 grid.
    fig, axes = plt.subplots(rows, cols, figsize=(width, height), squeeze=False)
    axes = axes.flatten()
    for ax in axes:
        ax.axis("off")

    # Distinct random indices; never more than the dataset holds.
    n_shown = min(len(axes), len(dataset))
    indices = torch.randperm(len(dataset))[:n_shown].tolist()

    for ax, idx in zip(axes, indices):
        x, y = dataset[idx]
        ax.imshow(x.reshape(image_shape), cmap="gray")
        ax.set_title(str(int(y)))
    return fig

show_random_samples(train_ds);