In [1]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [2]:
# Reproducibility: seed torch's global RNG before any random work (model weight init in
# later cells, shuffling below) so Restart & Run All reproduces the same results.
SEED = 0
BATCH_SIZE = 64
torch.manual_seed(SEED)

# ToTensor scales pixel values to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
fashion_mnist_train = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
fashion_mnist_test = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

# A dedicated, seeded generator makes the shuffle order reproducible independently of
# any other use of the global RNG (e.g. weight initialisation).
trainloader = DataLoader(
    fashion_mnist_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
# Evaluation order does not matter, so no shuffling for the test set.
testloader = DataLoader(fashion_mnist_test, batch_size=BATCH_SIZE, shuffle=False)


Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26.4M/26.4M [00:02<00:00, 11.1MB/s]


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29.5k/29.5k [00:00<00:00, 170kB/s]


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4.42M/4.42M [00:01<00:00, 3.24MB/s]


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5.15k/5.15k [00:00<00:00, 15.0MB/s]


Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw



In [3]:
def _conv_stage(in_ch, out_ch):
    """Return ``[conv3x3, ReLU, conv3x3, ReLU]``; padding=1 keeps the spatial size unchanged."""
    return [
        nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
    ]


# Small CNN producing 10 logits from single-channel images.
# Three double-conv stages widen channels 1 -> 8 -> 16 -> 32. The first two stages end in a
# 2x2 max-pool; the last is collapsed to 1x1 by global average pooling before a 2-layer MLP head.
# Arguments are evaluated left to right, so modules are built (and initialised) in the
# same order as a flat listing would build them.
model = nn.Sequential(
    *_conv_stage(1, 8),
    nn.MaxPool2d(2, 2),
    *_conv_stage(8, 16),
    nn.MaxPool2d(2, 2),
    *_conv_stage(16, 32),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(),
    nn.Linear(32, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)



In [5]:
# Train `model` with Adam + cross-entropy, then report top-1 accuracy on the test set.
# NOTE: `model` is created in the model-definition cell and is NOT re-initialised here, so
# re-running this cell continues training the existing weights (with a fresh optimizer).
# Re-run the model cell first for a from-scratch run.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 77


def train_one_epoch(net, loader, loss_fn, opt):
    """Run one optimisation pass over ``loader`` and return the mean per-sample loss.

    Args:
        net: the ``nn.Module`` being trained; switched to train mode.
        loader: iterable of ``(inputs, labels)`` batches.
        loss_fn: criterion returning a batch-mean loss (``nn.CrossEntropyLoss`` defaults
            to ``reduction='mean'``).
        opt: optimizer over ``net``'s parameters.

    Returns:
        ``sum(batch_loss * batch_size) / samples_seen``. The smaller final batch is
        therefore weighted correctly rather than counted as a full batch. Returns 0.0 for
        an empty loader.
    """
    net.train()
    loss_sum = 0.0
    n_seen = 0
    for inputs, labels in loader:
        opt.zero_grad()
        loss = loss_fn(net(inputs), labels)
        loss.backward()
        opt.step()
        # loss is a batch mean; re-weight by batch size to get a per-sample average.
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        n_seen += batch_size
    return loss_sum / n_seen if n_seen else 0.0


def evaluate_accuracy(net, loader):
    """Return ``(correct, total)`` top-1 prediction counts of ``net`` over ``loader``.

    Runs in eval mode with gradient tracking disabled.
    """
    net.eval()
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            predicted = net(inputs).argmax(dim=1)
            n_total += labels.size(0)
            n_correct += (predicted == labels).sum().item()
    return n_correct, n_total


for epoch in range(num_epochs):
    epoch_loss = train_one_epoch(model, trainloader, criterion, optimizer)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

correct, total = evaluate_accuracy(model, testloader)
print(f"Accuracy on the test dataset: {100 * correct / total:.2f}%")



Epoch [1/77], Loss: 0.1276
Epoch [2/77], Loss: 0.1242
Epoch [3/77], Loss: 0.1229
Epoch [4/77], Loss: 0.1204
Epoch [5/77], Loss: 0.1196
Epoch [6/77], Loss: 0.1162
Epoch [7/77], Loss: 0.1156
Epoch [8/77], Loss: 0.1135
Epoch [9/77], Loss: 0.1105
Epoch [10/77], Loss: 0.1109
Epoch [11/77], Loss: 0.1058
Epoch [12/77], Loss: 0.1082
Epoch [13/77], Loss: 0.1035
Epoch [14/77], Loss: 0.1033
Epoch [15/77], Loss: 0.1016
Epoch [16/77], Loss: 0.0984
Epoch [17/77], Loss: 0.0976
Epoch [18/77], Loss: 0.0965
Epoch [19/77], Loss: 0.0975
Epoch [20/77], Loss: 0.0925
Epoch [21/77], Loss: 0.0898
Epoch [22/77], Loss: 0.0923
Epoch [23/77], Loss: 0.0864
Epoch [24/77], Loss: 0.0896
Epoch [25/77], Loss: 0.0860
Epoch [26/77], Loss: 0.0853
Epoch [27/77], Loss: 0.0823
Epoch [28/77], Loss: 0.0817
Epoch [29/77], Loss: 0.0847
Epoch [30/77], Loss: 0.0811
Epoch [31/77], Loss: 0.0792
Epoch [32/77], Loss: 0.0779
Epoch [33/77], Loss: 0.0761
Epoch [34/77], Loss: 0.0789
Epoch [35/77], Loss: 0.0758
Epoch [36/77], Loss: 0.0794
E