In [12]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

## Loading dataset

In [13]:
# Build the train and test splits of FashionMNIST in one pass. Both splits are
# stored under ./data (downloaded on first use) and each gets its own
# ToTensor() transform instance, exactly as if constructed separately.
training_data, test_data = (
    datasets.FashionMNIST(
        root="data",
        train=is_train_split,
        download=True,
        transform=ToTensor(),
    )
    for is_train_split in (True, False)
)


In [14]:
# Mini-batch size shared by the training and evaluation loaders.
batch_size = 64
train_data_loader = DataLoader(dataset=training_data, batch_size=batch_size)
test_data_loader = DataLoader(dataset=test_data, batch_size=batch_size)

# Sanity check: show the feature shape, label shape and label dtype of the
# first test batch only, then stop iterating.
for x, y in test_data_loader:
    print(x.shape, y.shape, y.dtype, sep="\n")
    break


torch.Size([64, 1, 28, 28])
torch.Size([64])
torch.int64


## Creating model

In [15]:
class NeuralNetwork(nn.Module):
    """Fully connected classifier with two hidden ReLU layers.

    The input is flattened from dimension 1 onwards (``nn.Flatten`` defaults),
    passed through ``Linear -> ReLU -> Linear -> ReLU -> Linear`` and returned
    as raw, un-normalised class scores (logits).

    Args:
        in_features: Number of values per sample after flattening.
            Defaults to ``28 * 28``.
        hidden_features: Width of both hidden layers. Defaults to ``512``.
        num_classes: Number of output logits. Defaults to ``10``.
    """

    def __init__(self, in_features=28 * 28, hidden_features=512, num_classes=10):
        super().__init__()
        # Attribute names are kept as `flatten` / `stack` so that state_dict
        # keys (e.g. "stack.0.weight") and the printed repr stay unchanged.
        self.flatten = nn.Flatten()
        self.stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for a batch ``x``.

        ``x`` must carry the batch on dimension 0; everything after it is
        flattened and must multiply out to ``in_features``.
        """
        x = self.flatten(x)
        logits = self.stack(x)
        return logits

# Instantiate with the default sizes and show the layer layout.
model = NeuralNetwork()
print(model)

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)


## Optimising the weights 

In [16]:
# Cross-entropy criterion applied to the model's raw logits and integer labels.
loss_fn = nn.CrossEntropyLoss()
# Plain stochastic gradient descent over every parameter of `model`.
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-3)

In [17]:
def train(data_loader, model, loss_fn, optimizer):
    """Run one optimisation pass over every batch in ``data_loader``.

    Args:
        data_loader: Iterable of ``(x, y)`` batches that also exposes a
            ``.dataset`` attribute supporting ``len()`` (used only for the
            progress print-out).
        model: Module being trained; it is switched to training mode.
        loss_fn: Callable ``loss_fn(prediction, y)`` returning a scalar loss
            tensor.
        optimizer: Optimiser bound to ``model``'s parameters.

    Side effects:
        Updates the model parameters in place and, every 100th batch, prints
        the batch loss, an approximate count of samples seen, and the dataset
        size.
    """
    size = len(data_loader.dataset)
    model.train()
    for batch, (x, y) in enumerate(data_loader):
        # Clear gradients *before* the backward pass. Zeroing only after
        # step() would let gradients already present on the parameters when
        # train() is called leak into the first update.
        optimizer.zero_grad()

        prediction = model(x)
        loss = loss_fn(prediction, y)

        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            # (batch + 1) * len(x) approximates the number of samples
            # processed so far, using the current batch's size.
            print(loss.item(), (batch + 1) * len(x), size)

def test(data_loader, model, loss_fn):
    size = len(data_loader.dataset)
    num_batches = len(data_loader)
    model.eval()
    test_loss, correct = 0, 0
    
    with torch.no_grad():
        for x, y in data_loader:
            pred = model(x)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(correct, test_loss)



## Training and evaluation

In [18]:
# Number of full passes over the training data.
EPOCH = 5
for t in range(1, EPOCH + 1):
    print(f'Epoch number {t}')
    # Fit on the training loader, then score on the test loader after each pass.
    train(train_data_loader, model, loss_fn, optimizer)
    test(test_data_loader, model, loss_fn)

Epoch number 1
2.30128812789917 64 60000


2.2812867164611816 6464 60000
2.2644267082214355 12864 60000
2.2592504024505615 19264 60000
2.234238624572754 25664 60000
2.208918571472168 32064 60000
2.218365430831909 38464 60000
2.178802967071533 44864 60000
2.175074338912964 51264 60000
2.143059253692627 57664 60000
0.4244 2.132374804490691
Epoch number 2
2.1429049968719482 64 60000
2.1257691383361816 6464 60000
2.0662953853607178 12864 60000
2.0893545150756836 19264 60000
2.032365322113037 25664 60000
1.9657355546951294 32064 60000
2.0042550563812256 38464 60000
1.913061499595642 44864 60000
1.9214746952056885 51264 60000
1.849424958229065 57664 60000
0.5528 1.843077358166883
Epoch number 3
1.8790148496627808 64 60000
1.837816834449768 6464 60000
1.7205452919006348 12864 60000
1.7768476009368896 19264 60000
1.676139235496521 25664 60000
1.6176562309265137 32064 60000
1.6581358909606934 38464 60000
1.5526419878005981 44864 60000
1.581667423248291 51264 60000
1.4815558195114136 57664 60000
0.607 1.4953518756635629
Epoch number 4
1.

## Saving the model to disk

In [20]:
# Persist only the parameters (state dict), not the whole module object.
torch.save(obj=model.state_dict(), f="model.pth")

In [22]:
# Human-readable label names; the list position is the integer class id that
# the dataset labels and the model's output index refer to.
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]
model.eval()
# Fetch the first test sample once instead of indexing the dataset twice.
x, y = test_data[0]
with torch.no_grad():
    # Add an explicit batch dimension. Passing the bare (1, 28, 28) image only
    # worked because the single channel axis was treated as the batch axis.
    pred = model(x.unsqueeze(0))
    # .item() converts the arg-max tensor into a plain int for list indexing.
    predicted, actual = classes[pred[0].argmax(0).item()], classes[y]
    print(f'Predicted: {predicted}, Actual: {actual}')


Predicted: Ankle boot, Actual: Ankle boot
