In [6]:
from d2l import torch as d2l

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

In [19]:
from matplotlib import pyplot as plt
from torch import nn
import torch
from d2l import torch as d2l

X = torch.rand(size=(1,1,28,28))

class MyNet(nn.Sequential):
    def __init__(self):
        super().__init__(nn.Conv2d(1,6, kernel_size=5, padding=2), nn.ReLU(),  
                         nn.AvgPool2d(kernel_size=2, stride=2),                
                         nn.Conv2d(6,16, kernel_size=5, padding=2), nn.ReLU(),
                         nn.AvgPool2d(kernel_size=2, stride=2),
                         nn.Flatten(),
                         nn.Linear(784, 256),
                         nn.ReLU(),
                         nn.Linear(256,96),
                         nn.ReLU(),
                         nn.Linear(96,10))


        

def accuracy(net, X, y):
    if isinstance(net, nn.Module):
        net.eval()
    return (net(X).argmax(axis=1) == y).sum()

def evaluate_accuracy(net, data_iter, device):
    if isinstance(net, nn.Module):
        net.eval()
    metric = d2l.Accumulator(2)
    for X, y in data_iter:
        X = X.to(device)
        y = y.to(device)
        metric.add(accuracy(net, X, y), y.numel())
    return metric[0] / metric[1]
        


def train_epoch(net, data_iter, loss, updater, device):
    metric = d2l.Accumulator(3)
    if isinstance(net, nn.Module):
        net.train()
        net = net.to(device)
        loss.to(device)
    for X, y in data_iter:
        X = X.to(device)
        y = y.to(device)
        updater.zero_grad()
        l = loss(net(X), y)
        l.backward()
        updater.step()
        acc = accuracy(net, X, y)
        metric.add(l,acc, y.numel())
    return metric[0] / metric[2], metric[1] / metric[2] 
    
        

config = {
    'num_epoch': 30,
    'learning_rate': 1e-2
}

net = MyNet()

def init_weights(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight)
        nn.init.zeros_(m.bias)

def train(net, train_iter, test_iter, loss, device):
    net.apply(init_weights)
    updater = torch.optim.SGD(params=net.parameters() ,lr=config['learning_rate'])
    num_epoch = config['num_epoch']
    for epoch in range(num_epoch):
        train_loss, train_acc = train_epoch(net, train_iter, loss, updater, device)
        test_acc = evaluate_accuracy(net, test_iter, device)
        print(f'epoch {epoch} : train_loss {train_loss} train_acc {train_acc} / test_acc {test_acc}')
        
    
    
loss = nn.CrossEntropyLoss()

train(net, train_iter, test_iter, loss, torch.device('cuda:0'))
    
    
        

epoch 0 : train_loss 0.0069884165634711585 train_acc 0.37726666666666664 / test_acc 0.5925
epoch 1 : train_loss 0.0033265372504790625 train_acc 0.7045833333333333 / test_acc 0.7091
epoch 2 : train_loss 0.0028586305359999337 train_acc 0.7447333333333334 / test_acc 0.7387
epoch 3 : train_loss 0.002632752828796705 train_acc 0.7653166666666666 / test_acc 0.7412
epoch 4 : train_loss 0.0024763814533750217 train_acc 0.77975 / test_acc 0.7287
epoch 5 : train_loss 0.0023374227305253347 train_acc 0.7932166666666667 / test_acc 0.7693
epoch 6 : train_loss 0.002241135024527709 train_acc 0.8043833333333333 / test_acc 0.7885
epoch 7 : train_loss 0.0021593581770857177 train_acc 0.8131166666666667 / test_acc 0.7508
epoch 8 : train_loss 0.002075414804617564 train_acc 0.8209 / test_acc 0.7944
epoch 9 : train_loss 0.0020050736462076503 train_acc 0.8298333333333333 / test_acc 0.7868
epoch 10 : train_loss 0.0019484173640608787 train_acc 0.8330666666666666 / test_acc 0.8068
epoch 11 : train_loss 0.0018874552