In [1]:
import os

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor


class NeuralNetwork(nn.Module):
    """Fully connected classifier producing 10 raw class scores (logits).

    Each sample is flattened to 28 * 28 = 784 features (e.g. a 1x28x28 image)
    and passed through two 512-unit hidden layers with ReLU activations.
    """

    def __init__(self):
        super().__init__()
        hidden_units = 512
        self.flatten: nn.Flatten = nn.Flatten()
        self.linear_relu_stack: nn.Sequential = nn.Sequential(
            nn.Linear(28 * 28, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 10),
        )

    def forward(self, x):
        """Return the unnormalised class scores for the input batch ``x``."""
        return self.linear_relu_stack(self.flatten(x))


class NeuralNetworkManager:
    """Bundle a NeuralNetwork with its device, loss function and optimizer.

    Provides one-epoch training, evaluation, and state-dict save/load helpers.
    """

    def __init__(self, learning_rate: float = 1e-3):
        """Build the model on CUDA when available, otherwise on CPU.

        Args:
            learning_rate: SGD step size; defaults to 1e-3, the value that was
                previously hard-coded.
        """
        self.device: str = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using {self.device} device")
        self.model: NeuralNetwork = NeuralNetwork().to(self.device)
        # CrossEntropyLoss expects unnormalised scores (logits) from the model.
        self.loss_fn: nn.CrossEntropyLoss = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.model.parameters(), lr=learning_rate)

    def train(self, data_loader: DataLoader):
        """Run one training epoch over ``data_loader``.

        Every 100 batches, prints the batch loss and ``batch_index * batch_len``
        (the number of samples preceding this batch when batches are full-size)
        against the dataset size.
        """
        size: int = len(data_loader.dataset)
        self.model.train()
        for batch, (X, y) in enumerate(data_loader):
            X, y = X.to(self.device), y.to(self.device)

            pred = self.model(X)
            loss = self.loss_fn(pred, y)

            # Clear gradients from the previous step before backpropagating.
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if batch % 100 == 0:
                # Separate names so ``loss`` keeps referring to the tensor.
                batch_loss, current = loss.item(), batch * len(X)
                print(f"loss: {batch_loss:>7f} [{current:>5d}/{size:>5d}]")

    def test(self, data_loader: DataLoader):
        """Evaluate the model on ``data_loader`` without updating weights.

        Prints accuracy (as a percentage) and the mean per-batch loss.

        Returns:
            ``(accuracy, avg_loss)``: the fraction of correctly classified
            samples in [0, 1], and the mean of the per-batch losses.

        Raises:
            ValueError: if ``data_loader`` yields no samples or no batches
                (previously this surfaced as a ZeroDivisionError).
        """
        size: int = len(data_loader.dataset)
        num_batches: int = len(data_loader)
        if size == 0 or num_batches == 0:
            raise ValueError("test() requires a data loader with at least one sample")

        self.model.eval()
        total_loss, correct = 0.0, 0.0
        with torch.no_grad():
            for X, y in data_loader:
                X, y = X.to(self.device), y.to(self.device)
                pred = self.model(X)
                total_loss += self.loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        avg_loss = total_loss / num_batches
        accuracy = correct / size
        print(f"Test Error: \n Accuracy: {(100 * accuracy):>0.1f}%, Avg loss: {avg_loss:>8f} \n")
        return accuracy, avg_loss

    def save_parameters(self, filepath):
        """Write the model's ``state_dict`` to ``filepath`` via ``torch.save``."""
        torch.save(self.model.state_dict(), filepath)

    def load_parameters(self, filepath):
        """Load a ``state_dict`` written by ``save_parameters`` and switch to eval mode.

        ``map_location`` remaps stored tensors onto this manager's device, so a
        checkpoint saved on a CUDA machine can be loaded on a CPU-only one.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (recent PyTorch versions also accept weights_only=True).
        state_dict = torch.load(filepath, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()

In [2]:
def run_training_model(train_loader=None, test_loader=None, params_path=None, epochs: int = 5):
    """Train a fresh model, evaluating after every epoch, then save its weights.

    Args:
        train_loader: Training DataLoader; defaults to the notebook-level
            ``train_data_loader``.
        test_loader: Evaluation DataLoader; defaults to the notebook-level
            ``test_data_loader``.
        params_path: Destination for the saved state dict; defaults to the
            notebook-level ``model_params_path``.
        epochs: Number of training passes over ``train_loader`` (default 5).

    Returns:
        The NeuralNetworkManager holding the trained model.
    """
    # Fall back to globals from the data-loading cell so zero-argument calls
    # still work; passing values explicitly removes the hidden dependency on
    # cell execution order.
    if train_loader is None:
        train_loader = train_data_loader
    if test_loader is None:
        test_loader = test_data_loader
    if params_path is None:
        params_path = model_params_path

    model_manager: NeuralNetworkManager = NeuralNetworkManager()
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        model_manager.train(data_loader=train_loader)
        model_manager.test(data_loader=test_loader)
    model_manager.save_parameters(params_path)
    return model_manager


def run_loading_model(test_loader=None, params_path=None):
    """Rebuild a model from saved weights and evaluate it.

    Args:
        test_loader: Evaluation DataLoader; defaults to the notebook-level
            ``test_data_loader``.
        params_path: Path of the saved state dict; defaults to the
            notebook-level ``model_params_path``.

    Returns:
        The NeuralNetworkManager holding the loaded model.
    """
    # Same global fallback as run_training_model, keeping no-arg calls valid.
    if test_loader is None:
        test_loader = test_data_loader
    if params_path is None:
        params_path = model_params_path

    model_manager: NeuralNetworkManager = NeuralNetworkManager()
    model_manager.load_parameters(params_path)
    model_manager.test(data_loader=test_loader)
    return model_manager

In [3]:
if __name__ == '__main__':
    # Seed PyTorch's RNG so weight initialisation and training-batch shuffling
    # are reproducible under Restart & Run All.
    seed: int = 0
    torch.manual_seed(seed)

    data_path: str = os.path.join("resources", "datasets", "fashion-mnist")
    model_params_path: str = os.path.join("resources", "models", "nn-fashion-mnist.pth")
    os.makedirs(data_path, exist_ok=True)
    os.makedirs(os.path.dirname(model_params_path), exist_ok=True)

    training_data: datasets.FashionMNIST = datasets.FashionMNIST(
        root=data_path,
        train=True,
        download=True,
        transform=ToTensor(),
    )

    test_data: datasets.FashionMNIST = datasets.FashionMNIST(
        root=data_path,
        train=False,
        download=True,
        transform=ToTensor(),
    )

    batch_size: int = 64
    # Reshuffle training samples every epoch; evaluation order does not
    # affect the reported metrics, so the test loader stays unshuffled.
    train_data_loader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
    test_data_loader = DataLoader(test_data, batch_size=batch_size)

    # Peek at a single batch to confirm tensor shapes and dtypes before training.
    X, y = next(iter(test_data_loader))
    print(f"Shape of X [N, C, H, W] {X.shape}")
    print(f"Shape of y {y.shape}, {y.dtype}")

    run_training_model()
    run_loading_model()

Shape of X [N, C, H, W] torch.Size([64, 1, 28, 28])
Shape of y torch.Size([64]), torch.int64
Using cpu device
Epoch 1
-------------------------------
loss: 2.296741 [    0/60000]
loss: 2.289731 [ 6400/60000]
loss: 2.266238 [12800/60000]
loss: 2.264049 [19200/60000]
loss: 2.243844 [25600/60000]
loss: 2.208752 [32000/60000]
loss: 2.215708 [38400/60000]
loss: 2.176661 [44800/60000]
loss: 2.178626 [51200/60000]
loss: 2.137944 [57600/60000]
Test Error: 
 Accuracy: 44.4%, Avg loss: 2.140388 

Epoch 2
-------------------------------
loss: 2.152412 [    0/60000]
loss: 2.147752 [ 6400/60000]
loss: 2.081513 [12800/60000]
loss: 2.102163 [19200/60000]
loss: 2.049279 [25600/60000]
loss: 1.980403 [32000/60000]
loss: 2.009014 [38400/60000]
loss: 1.924953 [44800/60000]
loss: 1.941730 [51200/60000]
loss: 1.850578 [57600/60000]
Test Error: 
 Accuracy: 58.7%, Avg loss: 1.860476 

Epoch 3
-------------------------------
loss: 1.897690 [    0/60000]
loss: 1.870596 [ 6400/60000]
loss: 1.743535 [12800/60000]