In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader


In [2]:
BATCH_SIZE = 64

training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

In [3]:
train_loader = DataLoader(
    dataset=training_data,
    batch_size=BATCH_SIZE,
    shuffle = False,
)
test_loader = DataLoader(
    dataset=datasets.MNIST(root="data", train=False, transform=ToTensor()),
    batch_size=64,
    shuffle=False,
)

In [5]:
training_data[0][0].shape

torch.Size([1, 28, 28])

In [6]:
class Net(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Net, self).__init__()
        self.l1 = nn.Linear(input_size, hidden_size)
        self.l2 = nn.Linear(hidden_size, output_size)

    def forward(self, x ):
        x = F.relu(self.l1(x))
        x = F.log_softmax(self.l2(x))
        return x


In [7]:
model = Net(28*28, 64, 10)
NUM_EPOCHS = 10

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(NUM_EPOCHS):
    for i, (images, labels) in enumerate(train_loader):
        images = images.view(-1, 28 * 28)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i+1) % 100 == 0:
            print(f'epoch: {epoch} + loss: {loss}')

with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.view(-1, 28 * 28)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print(
            "Accuracy of the model on the {} test images: {} %".format(
                total, 100 * correct / total
            ),
        )

  x = F.log_softmax(self.l2(x))


epoch: 0 + loss: 0.4303024709224701
epoch: 0 + loss: 0.47723910212516785
epoch: 0 + loss: 0.3501252233982086
epoch: 0 + loss: 0.3403085768222809
epoch: 0 + loss: 0.2943264842033386
epoch: 0 + loss: 0.3741249144077301
epoch: 0 + loss: 0.24318695068359375
epoch: 0 + loss: 0.29617130756378174
epoch: 0 + loss: 0.164317324757576
epoch: 1 + loss: 0.1255573332309723
epoch: 1 + loss: 0.2800259590148926
epoch: 1 + loss: 0.1934867948293686
