In [1]:
import torch
import torchvision
import torchvision.transforms as transforms


In [2]:
def load_training_dataset(root="data", seed=None):
    """Load the full MNIST training split as two in-memory tensors.

    Images are converted with ``ToTensor`` and then normalised with mean 0.5
    and std 0.5 on their single channel. The whole split is read as one
    DataLoader batch, so the returned tensors hold every sample, in shuffled
    order.

    Args:
        root: directory where MNIST is stored / downloaded to.
        seed: optional integer seed for the shuffle, giving a reproducible
            sample order. ``None`` (the default) shuffles with PyTorch's
            global RNG.

    Returns:
        ``(images, labels)`` tensors; for MNIST's training split these have
        sizes ``(60000, 1, 28, 28)`` and ``(60000,)``.
    """
    transform = transforms.Compose(
        # (0.5,) is a one-element tuple (one mean / one std per channel);
        # a bare (0.5) is just the float 0.5.
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )

    trainset = torchvision.datasets.MNIST(
        root=root, train=True, download=True, transform=transform
    )

    # Dedicated RNG only when a seed is requested; None keeps the global RNG.
    generator = None if seed is None else torch.Generator().manual_seed(seed)

    # A single batch containing the entire split. With only one batch, only
    # one worker would ever do the loading, so extra worker processes add
    # nothing but startup and transfer cost; load in the main process.
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=len(trainset),
        shuffle=True,
        num_workers=0,
        generator=generator,
    )

    dataset_training_images, dataset_training_labels = next(iter(trainloader))

    return dataset_training_images, dataset_training_labels


dataset_training_images, dataset_training_labels = load_training_dataset()

# Sanity check: report the size of the image tensor, then the label tensor.
print(dataset_training_images.shape)
print(dataset_training_labels.shape)


torch.Size([60000, 1, 28, 28])
torch.Size([60000])


In [9]:
class MNISTClassifier(torch.nn.Module):
    """Fully connected classifier for 28x28 single-channel images.

    Architecture: 784 -> 64 -> 64 -> 10, with ReLU after the first two
    linear layers. ``forward`` returns raw logits (no softmax), which is the
    form ``torch.nn.CrossEntropyLoss`` consumes in the training loop.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept as-is: they determine the state_dict keys
        # and the printed module repr.
        self.input = torch.nn.Linear(in_features=28 * 28, out_features=64)
        self.hidden = torch.nn.Linear(in_features=64, out_features=64)
        self.output = torch.nn.Linear(in_features=64, out_features=10)

    def forward(self, tensor_images):
        """Map a batch of images to class logits.

        Args:
            tensor_images: tensor whose trailing elements form 28*28-pixel
                images, e.g. ``(batch, 1, 28, 28)``; it is flattened to
                ``(batch, 784)``.

        Returns:
            Logits of shape ``(batch, 10)``.
        """
        # reshape (not view) so non-contiguous inputs such as transposed or
        # sliced tensors work too; contiguous inputs are not copied.
        layer_input = tensor_images.reshape(-1, 1 * 28 * 28)

        layer_input = torch.relu(self.input(layer_input))
        layer_input = torch.relu(self.hidden(layer_input))

        return self.output(layer_input)


mnist_classifier = MNISTClassifier()
# Display the architecture of the instance that actually gets trained,
# rather than constructing a second, throwaway model just to print it.
print(mnist_classifier)

MNISTClassifier(
  (input): Linear(in_features=784, out_features=64, bias=True)
  (hidden): Linear(in_features=64, out_features=64, bias=True)
  (output): Linear(in_features=64, out_features=10, bias=True)
)


In [13]:
# Training configuration:
#   * training_steps -- number of full passes over all training batches,
#   * loss           -- cross-entropy on the classifier's raw logits,
#   * optimiser      -- Adam over the model's parameters with its default
#                       hyperparameters.
training_steps = 20
loss = torch.nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(params=mnist_classifier.parameters())


In [5]:
def divide_in_batches_32(tensor_dataset, batch_size=32):
    """Split a tensor along its first dimension into consecutive batches.

    Args:
        tensor_dataset: tensor whose first dimension indexes samples.
        batch_size: maximum number of samples per batch. The default of 32
            is the size the function is named after.

    Returns:
        List of slices of ``tensor_dataset``, in order. Each slice holds
        ``batch_size`` samples except possibly the last, which holds the
        remainder. An empty dataset yields an empty list.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    number_samples = tensor_dataset.size(0)

    return [
        tensor_dataset[index : index + batch_size]
        for index in range(0, number_samples, batch_size)
    ]


# Batch images and labels with the same slicing, so image batch i lines up
# with label batch i. The batches are built once and reused every epoch.
list_batches_images, list_batches_labels = (
    divide_in_batches_32(dataset_training_images),
    divide_in_batches_32(dataset_training_labels),
)


In [14]:
def train_classifier_batches(mnist_classifier, loss, optimiser, list_batches_images, list_batch_labels, number_training_steps):
    """Train a classifier for several epochs over pre-built mini-batches.

    Each of the ``number_training_steps`` iterations is one full pass (epoch)
    over every (image batch, label batch) pair, in the same fixed order each
    time. After each epoch the mean per-batch loss is printed.

    Args:
        mnist_classifier: ``torch.nn.Module`` mapping an input batch to logits.
        loss: criterion called as ``loss(predictions, labels)`` that returns
            a scalar tensor.
        optimiser: optimiser over ``mnist_classifier``'s parameters.
        list_batches_images: sequence of input batches.
        list_batch_labels: sequence of target batches aligned with
            ``list_batches_images``.
        number_training_steps: number of epochs to run.

    Returns:
        List of floats: the mean training loss of each epoch (the same values
        that are printed).

    Raises:
        ValueError: if no batches are given, or if the numbers of image and
            label batches differ (``zip`` would otherwise silently drop the
            extras and skew the average).
    """
    number_batches = len(list_batches_images)
    if number_batches == 0:
        raise ValueError("no training batches given")
    if len(list_batch_labels) != number_batches:
        raise ValueError(
            f"got {number_batches} image batches but "
            f"{len(list_batch_labels)} label batches"
        )

    # Ensure training-mode behaviour for any layers that depend on it
    # (e.g. dropout / batch norm), in case the model was left in eval mode.
    mnist_classifier.train()

    epoch_losses = []
    for _ in range(number_training_steps):

        running_loss = 0.0

        for batch_image, batch_label in zip(list_batches_images, list_batch_labels):
            optimiser.zero_grad()

            # Call the module and criterion directly (not .forward) so any
            # registered hooks run.
            estimator_predictions = mnist_classifier(batch_image)
            value_loss = loss(estimator_predictions, batch_label)

            value_loss.backward()
            optimiser.step()

            # Accumulate a plain Python float, not the graph-carrying tensor.
            running_loss += value_loss.item()

        running_loss = running_loss / number_batches
        print("running loss:", running_loss)
        epoch_losses.append(running_loss)

    return epoch_losses

# Run the training loop; keyword arguments make the role of each input
# explicit, and the trailing semicolon suppresses display of any return value.
train_classifier_batches(
    mnist_classifier=mnist_classifier,
    loss=loss,
    optimiser=optimiser,
    list_batches_images=list_batches_images,
    list_batch_labels=list_batches_labels,
    number_training_steps=training_steps,
);

running loss: 0.3721207421153784
running loss: 0.18768107621769112
running loss: 0.1447690345466137
running loss: 0.12002007986195386
running loss: 0.10544551107548178
running loss: 0.09261269581522792
running loss: 0.08417433995548636
running loss: 0.07941650625485927
running loss: 0.0717719434825393
running loss: 0.06743311271382652
running loss: 0.06157788112061098
running loss: 0.05899152673821276
running loss: 0.05393540739576177
running loss: 0.05420743654293086
running loss: 0.04910440432143708
running loss: 0.045843940859489764
running loss: 0.04462584466006568
running loss: 0.041397535623369425
running loss: 0.04207585004148229
running loss: 0.039324569922172425


In [15]:
index_image_dataset = 0

# Add a leading batch dimension: (1, 28, 28) -> (1, 1, 28, 28).
image = dataset_training_images[index_image_dataset].unsqueeze(0)
label = dataset_training_labels[index_image_dataset]

# Inference only: eval mode plus no_grad avoids building an autograd graph.
# (The model has only Linear layers, so eval() changes nothing today; it keeps
# inference correct if dropout / batch norm are added later.)
# NOTE(review): this image comes from the training split, so a correct
# prediction says little about generalisation; evaluate on MNIST's held-out
# split (train=False) for that.
mnist_classifier.eval()
with torch.no_grad():
    prediction = mnist_classifier(image)
print("prediction:", prediction)
print("label:", label.item())

# Logits for a single image have shape (1, 10); take the argmax over classes.
argmax_prediction = torch.argmax(prediction, dim=1).item()
print("argmax_prediction:", argmax_prediction)


prediction: tensor([[-13.8465,  -4.5318,  21.3872,   5.0450, -26.0116,  -8.9086, -21.4646,
          12.7147,  -2.8321, -23.7281]], grad_fn=<AddmmBackward0>)
label: 2
argmax_prediction: 2
