In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import optim
from torch.utils.data import DataLoader
import torchvision
import matplotlib.pyplot as plt


def plot_curve(data):
    """Draw ``data`` as a blue line against its step index and show the figure.

    Args:
        data: Sized sequence of numeric values; element ``i`` is plotted at x = ``i``.
    """
    plt.figure()
    steps = range(len(data))
    plt.plot(steps, data, color="blue")
    # Configure the legend and place it in the upper-right corner.
    plt.legend(["value"], loc="upper right")
    plt.xlabel("step")
    plt.ylabel("value")
    plt.show()

def plot_image(img, label, name):
    """Show up to the first six images of a batch in a 2x3 grid, titled with their labels.

    Args:
        img: Indexable batch; ``img[i][0]`` is the 2-D image drawn for sample ``i``.
        label: Indexable batch whose items expose ``.item()``; ``label[i]`` is
            printed in the title of sample ``i``.
        name: Prefix used in every subplot title (``"<name>: <label>"``).
    """
    plt.figure()
    # Never index past the end of a batch that holds fewer than six samples.
    count = min(6, len(img), len(label))
    for i in range(count):
        plt.subplot(2, 3, i + 1)
        # Undo the Normalize(mean=0.1307, std=0.3081) applied by the data
        # loaders in this file so the pixels are displayed on their original scale.
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap="gray", interpolation="none")
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after every subplot and title exists, so all of them are
    # taken into account when spacing is computed.
    plt.tight_layout()
    plt.show()

def one_hot(label, depth=10):
    """Return a float one-hot encoding of integer class labels.

    Args:
        label: Tensor of class indices; it is flattened, so row ``i`` of the
            result encodes the ``i``-th element. ``label.size(0)`` rows are
            allocated.
        depth: Number of classes, i.e. the width of each one-hot row.

    Returns:
        A ``(label.size(0), depth)`` tensor of zeros with a single 1 per row,
        allocated on the same device as ``label``.
    """
    # Allocate on the label's device so scatter_ never mixes devices.
    out = torch.zeros(label.size(0), depth, device=label.device)
    # scatter_ requires an int64 index; .long() converts any integer dtype
    # (the legacy torch.LongTensor(...) constructor does not).
    idx = label.long().view(-1, 1)
    out.scatter_(dim=1, index=idx, value=1)
    return out

# Number of samples per mini-batch.
batch_size = 512

# Preprocessing shared by both splits: convert to a tensor, then normalize
# with mean 0.1307 and std 0.3081.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.1307, ), std=(0.3081, )),
])

# Load the datasets.
# shuffle=True randomizes the sample order; the training set generally needs
# it, the test set does not.
train_set = torchvision.datasets.MNIST(
    'mnist_data/', train=True, download=True, transform=mnist_transform)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

test_set = torchvision.datasets.MNIST(
    'mnist_data/', train=False, download=True, transform=mnist_transform)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

# Peek at one training batch: print its shapes and value range, then show samples.
x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image sample')

class Net(nn.Module):
    """Fully connected network with three linear layers: 784 -> 256 -> 64 -> 10."""

    def __init__(self) -> None:
        super().__init__()

        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Apply fc1 and fc2 with ReLU, then fc3 with no activation.

        Args:
            x: Input whose last dimension has 784 features (required by fc1).

        Returns:
            The output of fc3, with 10 values per input row.
        """
        hidden1 = F.relu(self.fc1(x))
        hidden2 = F.relu(self.fc2(hidden1))
        # The final layer is returned as-is: no activation is applied to it.
        return self.fc3(hidden2)


# Train the network for three epochs with SGD, recording the loss of every step.
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
train_loss = []

for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        # Flatten each sample to the 784 features Net.fc1 expects.
        # x.size(0) is the batch dimension; passing the whole x.size() to
        # view() is not a valid shape argument and fails.
        x = x.view(x.size(0), 784)
        out = net(x)
        y_onehot = one_hot(y)

        # Regress the 10 outputs onto the one-hot encoded targets.
        loss = F.mse_loss(out, y_onehot)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss.append(loss.item())

        # Report progress every 10 batches.
        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)

# Measure classification accuracy over the whole test set.
total_correct = 0
# No gradients are needed for evaluation, so skip building autograd graphs.
with torch.no_grad():
    for x, y in test_loader:
        # Flatten each sample to the 784 features Net.fc1 expects.
        x = x.view(x.size(0), 784)
        out = net(x)
        # The index of the largest output per row is the predicted class.
        pred = out.argmax(dim=1)
        correct = pred.eq(y).sum().float().item()
        total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print("test acc:", acc)

# Visualize predictions for the first test batch.
x, y = next(iter(test_loader))
# Flatten each sample to the 784 features Net.fc1 expects; the reshape must
# happen inside the single argument passed to net().
out = net(x.view(x.size(0), 784))
# Position of the maximum output per row, i.e. the predicted class.
pred = out.argmax(dim=1)
plot_image(x, pred, "test")

In [3]:
import numpy


# Small integer array used to contrast max() with argmax().
a = numpy.array((4, 43, 334, 12, 32))
# Largest value in the array.
numpy.max(a)
# Index at which that largest value occurs.
numpy.argmax(a)

2