In [12]:
import random

import numpy as np
import torch
import torchvision

# Hyperparameters and run configuration.
n_epochs = 3
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10  # presumably "log every N training batches" -- TODO confirm against the training loop

# For repeatable experiments every source of randomness has to be seeded:
# Python's `random`, NumPy and PyTorch each keep their own independent
# generator, so seeding only torch is not enough.
random_seed = 1
random.seed(random_seed)
np.random.seed(random_seed)
# cuDNN may select nondeterministic algorithms; disabling it trades GPU speed
# for run-to-run repeatability.
torch.backends.cudnn.enabled = False
# Kept as the last expression so the cell still displays the torch Generator.
torch.manual_seed(random_seed)


<torch._C.Generator at 0x7fecc89cc050>

#### Why use TorchVision?

Next we need DataLoaders for the dataset, and this is where TorchVision comes into play: it lets us load the MNIST dataset conveniently.

TorchVision offers a lot of handy transformations, such as cropping or normalization. 

In [15]:
# Build the train/test DataLoaders for MNIST.
# The values 0.1307 and 0.3081 used by Normalize() below are the global mean
# and standard deviation of the MNIST dataset; we take them as a given here.
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)
DATA_ROOT = "."  # MNIST files are downloaded to / read from DATA_ROOT/MNIST

# Both splits must be preprocessed identically, so the pipeline is defined once:
# ToTensor converts each image to a float tensor, then Normalize standardizes it.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(DATA_ROOT, train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size_train, shuffle=True)

# shuffle=True on the test split does not change evaluation metrics; it only
# randomizes which images end up in the first batch.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(DATA_ROOT, train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size_test, shuffle=True)


In [18]:
# Peek at a single test batch to check the tensor layout the loader yields.
# enumerate() produces (index, (images, labels)) pairs; only the first is taken.
examples = enumerate(test_loader)
batch_idx, first_batch = next(examples)
example_data, example_targets = first_batch
example_data.shape


torch.Size([1000, 1, 28, 28])