In [1]:
import gzip
import os
import pickle
import subprocess
from itertools import islice
from pathlib import Path

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

In [2]:
# Locate the MNIST pickle in ../data relative to the current working
# directory, downloading it with curl first if it is not there yet.
dir_path = Path().absolute()
dataset_path = dir_path.parent / "data/mnist.pkl.gz"
if not dataset_path.exists():
    print('Downloading dataset with curl ...')
    # Creates every missing ancestor and tolerates an already-existing
    # directory, unlike a bare os.mkdir.
    dataset_path.parent.mkdir(parents=True, exist_ok=True)
    url = 'http://ericjmichaud.com/downloads/mnist.pkl.gz'
    # An argument list bypasses the shell, so a path containing spaces or
    # shell metacharacters reaches curl verbatim. --fail makes curl exit
    # non-zero on an HTTP error instead of saving the server's error page
    # under the dataset's name.
    curl_exit = subprocess.run(
        ['curl', '--fail', '-L', url, '-o', str(dataset_path)]
    ).returncode
    if curl_exit != 0 and dataset_path.exists():
        # Discard a partial download so the next run retries cleanly.
        dataset_path.unlink()
if not dataset_path.exists():
    print('Download failed')
    # Stop here with an explicit message rather than letting gzip.open fail.
    raise FileNotFoundError(
        'MNIST dataset not found at {}'.format(dataset_path)
    )
print('Dataset acquired')
# NOTE(review): pickle.load can execute arbitrary code embedded in the file,
# and this archive is fetched over plain HTTP. Only run this against a source
# you trust; consider pinning a checksum of the download.
with gzip.open(dataset_path, 'rb') as f:
    mnist = pickle.load(f)
print('Loaded data to variable `mnist`')

Dataset acquired
Loaded data to variable `mnist`


In [3]:
# Floating-point precision used throughout the notebook; also installed as
# torch's default so newly created float tensors pick it up.
dtype = torch.float32
torch.set_default_dtype(dtype)

# Everything runs on the CPU.
device = torch.device('cpu')

In [4]:
class MNISTDataset(Dataset):
    """MNIST digits dataset held in memory as torch tensors."""
    def __init__(self, data, transform=None, dtype=None):
        """Build the dataset from an iterable of (image, label) pairs.

        The images and labels are stored as torch tensors, since keeping
        the dataset in memory inside a `Dataset` object as a Python list
        or a numpy array causes a multiprocessing-related memory leak.

        Args:
            data: non-empty sequence of (image, label) pairs. The images
                are stacked with `np.array`; each label is reduced to a
                class index with `np.argmax` along axis 1 (presumably
                one-hot vectors; verify against the pickle contents).
            transform: optional callable applied to the `(image, label)`
                tuple in `__getitem__`; it must return an
                `(image, label)` pair.
            dtype: torch dtype for the image tensor. Defaults to
                `torch.get_default_dtype()`, so the class does not depend
                on a notebook-global `dtype` variable being defined.
        """
        if dtype is None:
            dtype = torch.get_default_dtype()
        images, labels = zip(*data)
        self.images = torch.from_numpy(np.array(images)).to(dtype)
        self.labels = torch.tensor(np.argmax(labels, axis=1)).to(torch.long)
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the (image, label) pair at `idx`, transformed if a
        transform was supplied."""
        image, label = self.images[idx], self.labels[idx]
        if self.transform:
            image, label = self.transform((image, label))
        return image, label


In [5]:
# The first 60000 entries of `mnist` form the training set; whatever follows
# is the test set.
train_data, test_data = (
    MNISTDataset(split) for split in (mnist[:60000], mnist[60000:])
)

# Both loaders serve shuffled mini-batches of 20 samples.
training_loader, test_loader = (
    torch.utils.data.DataLoader(dataset, batch_size=20, shuffle=True)
    for dataset in (train_data, test_data)
)

In [None]:
# Walk the whole training loader 300 times without doing any work on the
# batches. NOTE(review): presumably a check that plain iteration does not
# grow memory over time; confirm the intent with the author.
for epoch in range(300):
    for images, labels in training_loader:
        # Drop the batch references immediately; nothing is computed.
        del images, labels
    print("Epoch {:3d} completed.".format(epoch))

Epoch   0 completed.
Epoch   1 completed.
Epoch   2 completed.
Epoch   3 completed.
Epoch   4 completed.
Epoch   5 completed.
Epoch   6 completed.
Epoch   7 completed.
Epoch   8 completed.
Epoch   9 completed.
Epoch  10 completed.
Epoch  11 completed.
Epoch  12 completed.
Epoch  13 completed.
Epoch  14 completed.
Epoch  15 completed.
Epoch  16 completed.
Epoch  17 completed.
Epoch  18 completed.
Epoch  19 completed.
Epoch  20 completed.
Epoch  21 completed.
Epoch  22 completed.
Epoch  23 completed.
Epoch  24 completed.
Epoch  25 completed.
Epoch  26 completed.
Epoch  27 completed.
Epoch  28 completed.
Epoch  29 completed.
Epoch  30 completed.
Epoch  31 completed.
Epoch  32 completed.
Epoch  33 completed.
Epoch  34 completed.
Epoch  35 completed.
Epoch  36 completed.
Epoch  37 completed.
Epoch  38 completed.
Epoch  39 completed.
Epoch  40 completed.
Epoch  41 completed.
Epoch  42 completed.
Epoch  43 completed.
Epoch  44 completed.
Epoch  45 completed.
Epoch  46 completed.
Epoch  47 com