In [None]:
import torch
import torchvision
from torch import nn
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

In [None]:
# Convert each PIL image to a float tensor scaled to [0, 1].
trans = transforms.Compose([transforms.ToTensor()])
# download=True is skipped by torchvision when the files already exist under
# ../data, but lets the notebook run end-to-end on a fresh machine instead of
# raising because the dataset is missing.
mnist_train = torchvision.datasets.FashionMNIST(
    root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
    root="../data", train=False, transform=trans, download=True)

In [None]:
batch_size = 256
# Reshuffle the training split every epoch; keep the test split in a fixed order.
train_iter = data.DataLoader(
    mnist_train, batch_size=batch_size, shuffle=True, num_workers=8)
test_iter = data.DataLoader(
    mnist_test, batch_size=batch_size, shuffle=False, num_workers=8)

In [None]:
num_inputs = 784   # 28 * 28 pixels per image, flattened
num_outputs = 10   # number of Fashion-MNIST classes
num_hiddens = 256

# nn.Parameter already turns its argument into a leaf tensor with
# requires_grad=True, so the random tensors are created without the flag;
# setting it before `* 0.01` only recorded a throwaway autograd graph.
# Weights start as N(0, 0.01^2) samples; biases start at zero.
W1 = nn.Parameter(torch.randn(num_inputs, num_hiddens) * 0.01)
b1 = nn.Parameter(torch.zeros(num_hiddens))
W2 = nn.Parameter(torch.randn(num_hiddens, num_outputs) * 0.01)
b2 = nn.Parameter(torch.zeros(num_outputs))

params = [W1, b1, W2, b2]

In [None]:
def relu(X):
    """Rectified linear unit computed element-wise as max(0, X).

    Returns a new tensor with the same shape, dtype and device as ``X``.
    """
    return torch.max(torch.zeros_like(X), X)

def net(X):
    """Forward pass of the one-hidden-layer MLP.

    Flattens each example to ``num_inputs`` features and applies the hidden
    affine layer followed by ReLU. It then applies the output affine layer.

    Returns raw (unnormalised) class scores of shape (batch, num_outputs).
    No activation is applied to the output layer. ``nn.CrossEntropyLoss``
    applies log-softmax itself. Clamping the logits with ReLU would map every
    negative score to 0, which erases the ranking among those classes and
    blocks their gradients.
    """
    X = X.reshape(-1, num_inputs)
    H = relu(X @ W1 + b1)
    return H @ W2 + b2

In [None]:
def accuracy(y_pred, y):
    """Count how many predictions in ``y_pred`` match the labels ``y``.

    ``y_pred`` may be either a tensor of class indices or a 2-D matrix of
    per-class scores with more than one column. In the matrix case, the
    arg-max of each row is taken as the predicted class.

    Returns the number of correct predictions (not a fraction) as a float.
    """
    is_score_matrix = y_pred.dim() > 1 and y_pred.size(1) > 1
    predicted = y_pred.argmax(dim=1) if is_score_matrix else y_pred
    hits = predicted.type(y.dtype) == y
    return float(hits.type(y.dtype).sum())

def evaluate_accuracy(net, data_iter):
    """Return the fraction of examples in ``data_iter`` that ``net`` classifies correctly.

    If ``net`` is a ``torch.nn.Module``, it is switched to evaluation mode
    first and left in that mode. Gradient tracking is disabled while iterating.
    """
    if isinstance(net, torch.nn.Module):
        net.eval()
    # Running totals: correctly classified examples and examples seen.
    n_correct, n_seen = 0.0, 0.0
    with torch.no_grad():
        for features, labels in data_iter:
            n_correct += accuracy(net(features), labels)
            n_seen += labels.numel()
    return n_correct / n_seen

def train_epoch(net, train_iter, criterion, optimizer):
    """Run one pass over ``train_iter``, updating the parameters via ``optimizer``.

    Each batch does the following:
      * back-propagates the mean of the loss returned by ``criterion``;
      * adds the sum of that loss to the running loss total.

    NOTE(review): this assumes ``criterion`` returns per-example losses
    (``reduction='none'``). With a reducing criterion, the reported loss
    would be scaled down by roughly the batch size.

    Returns ``(average training loss, training accuracy)`` over all examples
    seen during the epoch.
    """
    if isinstance(net, nn.Module):
        net.train()
    loss_total = correct_total = example_total = 0.0
    for features, labels in train_iter:
        scores = net(features)
        batch_loss = criterion(scores, labels)

        optimizer.zero_grad()
        batch_loss.mean().backward()
        optimizer.step()

        loss_total += float(batch_loss.sum())
        correct_total += accuracy(scores, labels)
        example_total += labels.numel()

    return loss_total / example_total, correct_total / example_total

def train(net, train_iter, test_iter, criterion, num_epochs, optimizer):
    """Train ``net`` for ``num_epochs`` epochs and plot progress live.

    After every epoch, the animator receives three values for that epoch:
    the average training loss, the training accuracy, and the test accuracy.
    """
    animator = d2l.Animator(
        xlabel='epoch',
        xlim=[1, num_epochs],
        ylim=[0, 1],
        legend=['train loss', 'train acc', 'test acc'],
    )
    for epoch_idx in range(1, num_epochs + 1):
        train_loss, train_acc = train_epoch(net, train_iter, criterion, optimizer)
        test_acc = evaluate_accuracy(net, test_iter)
        animator.add(epoch_idx, (train_loss, train_acc, test_acc))

In [None]:
# Unreduced cross-entropy (one loss value per example), so the training loop
# can average it for backprop and sum it for reporting.
criterion = nn.CrossEntropyLoss(reduction="none")

In [None]:
# Ten epochs of plain minibatch SGD over the hand-built MLP parameters.
epochs, lr = 10, 0.1
optimizer = torch.optim.SGD(params, lr=lr)

train(net, train_iter, test_iter, criterion, num_epochs=epochs, optimizer=optimizer)

In [None]:
def get_fashion_mnist_labels(labels):
    """Map numeric Fashion-MNIST class indices to their text names.

    ``labels`` may be any iterable whose items can be converted with
    ``int()``, for example a 1-D tensor of class ids.

    Returns a list of strings in the same order.
    """
    class_names = [
        't-shirt', 'trouser', 'pullover', 'dress', 'coat',
        'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot',
    ]
    return list(map(lambda idx: class_names[int(idx)], labels))

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot images on a ``num_rows`` x ``num_cols`` grid with optional titles.

    Args:
        imgs: iterable of images. Each is either a torch tensor or anything
            ``imshow`` accepts. Images beyond ``num_rows * num_cols`` are
            ignored; surplus axes are left untouched.
        num_rows, num_cols: grid shape.
        titles: optional sequence of titles, indexed in the same order as
            ``imgs``.
        scale: width/height (in inches) allotted to each grid cell.
    """
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False always returns a 2-D array of Axes, so .flatten() also
    # works for a 1x1 grid (where subplots would otherwise return a bare Axes).
    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # detach/cpu so tensors that require grad or live on a GPU can
            # still be converted to NumPy for plotting.
            img = img.detach().cpu().numpy()
        ax.imshow(img)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])

In [None]:
def predict(net, test_iter, n=8):
    """Show the first ``n`` images of one batch from ``test_iter`` with labels.

    Each image title has the true label on the first line and ``net``'s
    predicted label on the second line.

    Args:
        net: callable mapping a batch of images to per-class scores.
        test_iter: iterable of ``(X, y)`` batches; only the first batch is used.
        n: number of images to show (no more than the batch size are shown).
    """
    X, y = next(iter(test_iter))
    # Only the displayed examples need a prediction.
    X, y = X[:n], y[:n]
    with torch.no_grad():
        y_pred = net(X).argmax(axis=1)
    true_labels = get_fashion_mnist_labels(y)
    pred_labels = get_fashion_mnist_labels(y_pred)
    titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]
    show_images(X.reshape(-1, 28, 28), 1, n, titles=titles)

predict(net, test_iter, 16)