# Fully connected network for MNIST classification

In [1]:
import torch
from torch import nn
from torch.utils import data
from torchvision import datasets, transforms

## 1. Load data

利用 torchvision.datasets 加载 MNIST 数据集<br>
只需下载（download）一次：数据会缓存在 ./data 目录下，之后再次运行时不会重复下载。

In [2]:
# MNIST training / test splits, both rooted at ./data.
# ToTensor converts each PIL image into a float tensor scaled to [0, 1].
# Only the training split passes download=True; the test split assumes the
# files under ./data are already present (TODO confirm the download above
# always fetches the test files as well).
train_set = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test_set = datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)

利用 torch.utils.data.DataLoader 将数据集封装为按小批量（mini-batch）迭代的数据加载器

In [3]:
batch_size = 128  # mini-batch size shared by both loaders

# The training split is reshuffled every epoch; the test split keeps a fixed order.
train_iter = data.DataLoader(
    dataset=train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)
test_iter = data.DataLoader(
    dataset=test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

## 2. Model

定义模型：两层全连接网络（784 → 128 → 10），先将 28×28 的图像展平为 784 维向量

In [4]:
# Two-layer MLP: flatten each image, one hidden ReLU layer of width 128,
# then 10 output logits (one per digit class).
net = nn.Sequential(
    nn.Flatten(),         # presumably (N, 1, 28, 28) -> (N, 784) for MNIST + ToTensor
    nn.Linear(784, 128),  # hidden layer
    nn.ReLU(),
    nn.Linear(128, 10),   # raw logits; softmax is left to the loss function
)

定义策略（损失函数）和算法（优化器）

In [5]:
# CrossEntropyLoss takes raw logits plus integer class labels
# (it applies log-softmax internally), which is why the model has no softmax.
loss_fn = nn.CrossEntropyLoss()
# Plain mini-batch SGD over every parameter of `net`, learning rate 0.1.
optimizer = torch.optim.SGD(params=net.parameters(), lr=0.1)

## 3. Training

计算模型分类准确率的函数

In [6]:
def accuracy(data_iter, model):
    """Return the fraction of samples in ``data_iter`` that ``model`` classifies correctly.

    Args:
        data_iter: iterable of ``(x, y)`` batches, where ``y`` holds integer
            class labels and ``model(x)`` returns per-class scores along dim 1.
        model: a ``torch.nn.Module`` producing those scores.

    Returns:
        float: number of correct predictions divided by the number of samples.

    Raises:
        ValueError: if ``data_iter`` yields no samples.

    The model is put in eval mode for the pass and its previous train/eval mode
    is restored afterwards, so calling this inside a training loop is safe.
    """
    was_training = model.training
    model.eval()
    num = 0
    accurate_num = 0
    try:
        # Evaluation never calls backward(), so skip building autograd graphs.
        with torch.no_grad():
            for x, y in data_iter:
                y_pred = model(x)
                accurate_num += (torch.argmax(y_pred, dim=1) == y).sum().item()
                num += y.shape[0]
    finally:
        model.train(was_training)
    if num == 0:
        raise ValueError("accuracy() received a data_iter with no samples")
    return accurate_num / num

训练

In [7]:
num_epochs = 100
for epoch in range(num_epochs):
    # Defensive: ensure training mode at the start of every epoch.
    # It is a no-op for this Flatten/Linear/ReLU stack but matters if dropout
    # or batch-norm layers are added later.
    net.train()
    # Sum of the per-batch mean losses over the epoch, kept as a Python float.
    # Accumulating the loss tensor itself (``total_loss += loss``) would make
    # the accumulator carry autograd history for every batch of the epoch.
    total_loss = 0.0
    for x, y in train_iter:
        y_pred = net(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Report every 10 epochs; each report makes two extra full passes
    # (train and test splits) for the accuracy numbers.
    if (epoch + 1) % 10 == 0:
        print(f"epoch {epoch + 1} loss: {total_loss}, train accuracy: {accuracy(train_iter, net)}, test accuracy: {accuracy(test_iter, net)}")

epoch 10 loss: 50.840240478515625, train accuracy: 0.97255, test accuracy: 0.9669
epoch 20 loss: 26.60177993774414, train accuracy: 0.9862166666666666, test accuracy: 0.9764
epoch 30 loss: 16.633323669433594, train accuracy: 0.9918666666666667, test accuracy: 0.9784
epoch 40 loss: 11.172441482543945, train accuracy: 0.9962833333333333, test accuracy: 0.9785
epoch 50 loss: 7.783447265625, train accuracy: 0.99825, test accuracy: 0.9788
epoch 60 loss: 5.630319595336914, train accuracy: 0.9989333333333333, test accuracy: 0.9791
epoch 70 loss: 4.1936116218566895, train accuracy: 0.9995333333333334, test accuracy: 0.9804
epoch 80 loss: 3.2499313354492188, train accuracy: 0.9997833333333334, test accuracy: 0.9803
epoch 90 loss: 2.60109281539917, train accuracy: 0.99985, test accuracy: 0.9799
epoch 100 loss: 2.1415786743164062, train accuracy: 0.9999333333333333, test accuracy: 0.9804


## 4. Testing

In [8]:
# Final classification accuracy on the held-out MNIST test split.
test_accuracy = accuracy(test_iter, net)
print(test_accuracy)

0.9804
