In [109]:
import numpy as np
import pandas as pd
import torch
from torch import autograd
from torch.utils.data import TensorDataset, DataLoader
import torchvision as tv
from sklearn.metrics import mean_squared_error, accuracy_score
import matplotlib.pyplot as plt

In [110]:
BATCH_SIZE = 256
def show_pic(pic):
    plt.imshow(pic.numpy().reshape(28,28), cmap='gray')

In [111]:
train_dataset = tv.datasets.FashionMNIST('.', train=True, transform=tv.transforms.ToTensor(), download=True)
test_dataset = tv.datasets.FashionMNIST('.', train=False, transform=tv.transforms.ToTensor(), download=True)
train = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE)

In [143]:
nn = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(784, 512),
    torch.nn.BatchNorm1d(512),
    torch.nn.GELU(),
    torch.nn.Linear(512, 256),
    torch.nn.BatchNorm1d(256),
    torch.nn.GELU(),
    torch.nn.Linear(256, 10)
)

loss_func = torch.nn.CrossEntropyLoss()
trainer = torch.optim.Adam(nn.parameters(), lr=0.001)

for epoch in range(1, 16):
    train_loss = 0
    batches_num = 0
    
    train_correctly_predicted = 0
    train_total = 0
    
    for X, y in train:
        trainer.zero_grad()
        y_pred = nn(X)
        train_l = loss_func(y_pred, y)
        train_l.backward()
        trainer.step()
        
        train_loss += train_l
        batches_num += 1
        
        train_total += len(X)
        train_correctly_predicted += (y_pred.argmax(dim=1) == y).sum()
    
    train_loss /= batches_num
    train_accuracy = train_correctly_predicted / train_total
    
    test_pred_y = nn(test_dataset.data.to(torch.float32)).detach()
    test_true_y = test_dataset.targets
    
    test_loss = loss_func(test_pred_y, test_true_y)
    test_accuracy = accuracy_score(test_pred_y.argmax(dim=1), test_true_y)
    
    print(f"Epoch No.{epoch}, train loss: {train_loss}, train acc: {train_accuracy}, test loss: {test_loss}, test acc: {test_accuracy}")

print('Switched from Adam to SGD')
    
trainer = torch.optim.SGD(nn.parameters(), lr=0.01)

for epoch in range(16, 41):
    train_loss = 0
    batches_num = 0
    
    train_correctly_predicted = 0
    train_total = 0
    
    for X, y in train:
        trainer.zero_grad()
        y_pred = nn(X)
        train_l = loss_func(y_pred, y)
        train_l.backward()
        trainer.step()
        
        train_loss += train_l
        batches_num += 1
        
        train_total += len(X)
        train_correctly_predicted += (y_pred.argmax(dim=1) == y).sum()
    
    train_loss /= batches_num
    train_accuracy = train_correctly_predicted / train_total
    
    test_pred_y = nn(test_dataset.data.to(torch.float32)).detach()
    test_true_y = test_dataset.targets
    
    test_loss = loss_func(test_pred_y, test_true_y)
    test_accuracy = accuracy_score(test_pred_y.argmax(dim=1), test_true_y)
    
    print(f"Epoch No.{epoch}, train loss: {train_loss}, train acc: {train_accuracy}, test loss: {test_loss}, test acc: {test_accuracy}")


Epoch No.1, train loss: 0.433793306350708, train acc: 0.8463666439056396, test loss: 0.37361353635787964, test acc: 0.8651
Epoch No.2, train loss: 0.31046488881111145, train acc: 0.8860166668891907, test loss: 0.34683990478515625, test acc: 0.8736
Epoch No.3, train loss: 0.2689378261566162, train acc: 0.9006666541099548, test loss: 0.33656376600265503, test acc: 0.8776
Epoch No.4, train loss: 0.23747779428958893, train acc: 0.9125000238418579, test loss: 0.3339919447898865, test acc: 0.8815
Epoch No.5, train loss: 0.21094511449337006, train acc: 0.9231500029563904, test loss: 0.33238887786865234, test acc: 0.8846
Epoch No.6, train loss: 0.18722575902938843, train acc: 0.9321833252906799, test loss: 0.3363695442676544, test acc: 0.8859
Epoch No.7, train loss: 0.16669557988643646, train acc: 0.9396166801452637, test loss: 0.3570007085800171, test acc: 0.8853
Epoch No.8, train loss: 0.1510605663061142, train acc: 0.9451833367347717, test loss: 0.3640090823173523, test acc: 0.8879
Epoch No

In [145]:
pred_y = nn(test_dataset.data.to(torch.float32)).argmax(dim=1)
true_y = test_dataset.targets
print('Final accuracy on test dataset:', accuracy_score(pred_y, true_y))

Final accuracy on test dataset: 0.9034
