In [1]:
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from torchvision import datasets, transforms

In [2]:
# params
batch_size = 256
learning_rate = 1e-3
epochs = 10

In [3]:
# transform
trans = transforms.ToTensor()

In [4]:
# initialize the train and test datasets
train_loader = torch.utils.data.DataLoader(datasets.FashionMNIST(root='data', 
                                                                 train=True, 
                                                                 transform=trans, 
                                                                 download=True), 
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(datasets.FashionMNIST(root='data', 
                                                                train=False, 
                                                                transform=trans, 
                                                                download=True), 
                                          batch_size=batch_size)

In [5]:
# network set
w1, b1 = torch.randn(200,784,requires_grad=True),\
        torch.zeros(200,requires_grad=True)
w2, b2 = torch.randn(200,200,requires_grad=True),\
        torch.zeros(200, requires_grad=True)
w3, b3 = torch.randn(10,200,requires_grad=True),\
        torch.zeros(10,requires_grad=True)

# init
torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)

tensor([[ 0.1372, -0.0885,  0.0172,  ...,  0.1417, -0.0181,  0.0783],
        [ 0.0340, -0.1043,  0.0035,  ..., -0.0964, -0.0462,  0.0486],
        [ 0.1324,  0.0727,  0.2301,  ..., -0.0340,  0.0664, -0.0265],
        ...,
        [-0.0102,  0.0315,  0.1673,  ..., -0.0608, -0.1819,  0.0796],
        [-0.0459,  0.1167,  0.0115,  ...,  0.0491, -0.0645,  0.0405],
        [-0.1204,  0.0028, -0.1266,  ...,  0.0795, -0.0147,  0.1155]],
       requires_grad=True)

In [6]:
# forward
def forward(x):
    x = x@w1.T + b1
    x = F.relu(x)
    
    x = x@w2.T + b2
    x = F.relu(x)
    
    x = x@w3.T + b3
    x = F.softmax(x)
    return x

In [8]:
# optimizer
optimizer = torch.optim.SGD([w1,w2,w3,b1,b2,b3], lr=learning_rate)

# loss function
loss_fn = nn.CrossEntropyLoss()

In [14]:
# training
for epoch in range(epochs):
    # forward
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.reshape(-1, 28*28)
        y_hat = forward(data)
        
        loss = loss_fn(y_hat, target)
        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch_idx%100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))
    
    # test
    test_loss = 0
    correct = 0
    
    for data, target in test_loader:
        data = data.reshape(-1,28*28)
        
        logits = forward(data)
        test_loss += loss_fn(logits, target).item()
        
        pred = logits.data.max(1)[1]
        correct += pred.eq(target.data).sum()
    
    test_loss /= len(test_loader.dataset)
    
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

  x = F.softmax(x)



Test set: Average loss: 0.0091, Accuracy: 2270/10000 (23%)


Test set: Average loss: 0.0090, Accuracy: 2623/10000 (26%)


Test set: Average loss: 0.0090, Accuracy: 2850/10000 (28%)


Test set: Average loss: 0.0089, Accuracy: 3091/10000 (31%)


Test set: Average loss: 0.0088, Accuracy: 3362/10000 (34%)


Test set: Average loss: 0.0088, Accuracy: 3611/10000 (36%)


Test set: Average loss: 0.0087, Accuracy: 3826/10000 (38%)


Test set: Average loss: 0.0086, Accuracy: 4052/10000 (41%)


Test set: Average loss: 0.0086, Accuracy: 4189/10000 (42%)


Test set: Average loss: 0.0085, Accuracy: 4290/10000 (43%)

