In [None]:
import torch
from torch.utils.data import DataLoader
from torch import nn

import numpy as np

import torchvision.datasets as datasets
from torchvision.transforms import ToTensor

# mnist_train[4][0] is a PIL image. Standard python image library. But our neural network or pytorch setup is gonna expect a tensor and is not gonna be able to deal with a PIL image. So we need to convert that PIL image into a tensor.

# here this mnist_train does not support something like mnist_train[0: 32] to get a batch of images. So we have to use DataLoader for that. This operation is not implemented
mnist_train = datasets.MNIST(root='./data', download=True, train=True, transform=ToTensor())
mnist_test = datasets.MNIST(root='./data', download=True, train=False, transform=ToTensor())


# but then we have to do this for all of the images in the dataset. So instead of doing it manually, we can use the transform argument when we create the dataset object. We can pass in a transform function that will be applied to each image when it is loaded from the dataset. In this case, we are using the ToTensor() transform from torchvision.transforms, which converts a PIL image to a tensor.
# print(torch.tensor(mnist_train[4][0]))
# print(torch.tensor(mnist_train[4][0]).shape)
# print(torch.tensor(np.array(mnist_train[4][0])))

# So now here first thing is tensor not PIL image. Second thing is that the tensor has shape [1, 28, 28]. The first dimension is the number of channels. Since MNIST images are grayscale, there is only one channel. The next two dimensions are the height and width of the image, which are both 28 pixels. Image data is normalized to be between 0 and 1. So the pixel values are floats between 0 and 1. Because it is a bit difficult for neural networks to work with integer pixel values between 0 and 255. It is possible but a bit more difficult.
print(mnist_train[0])

(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000

In [8]:
train_dataloader = DataLoader(mnist_train, batch_size=32, shuffle=True)
test_dataloader = DataLoader(mnist_test, batch_size=32, shuffle=True)

# train_dataloader = DataLoader(mnist_train, batch_size=32, shuffle=True)
# test_dataloader = DataLoader(mnist_test, batch_size=32, shuffle=True)

In [15]:
for X, y in test_dataloader:
    print(X)
    print(y)
    print(X.shape)
    print(y.shape)
    break

# for X, y in train_dataloader:
#     print(X.shape)
#     print(y.shape)
#     break

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0.