In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import torch.onnx as onnx
import torchvision.models as models
import torch.nn.functional as F
import torchvision
import cv2
from matplotlib import pyplot as plt 
from tensorboardX import SummaryWriter


# MNIST training split. download=True is safe to leave on: torchvision only
# fetches the files when they are not already present under `root`.
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# MNIST test split, loaded with the same settings.
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

# Both loaders shuffle. The inference cell at the end of the file draws its
# sample images from test_dataloader.
train_dataloader = DataLoader(training_data, batch_size=100, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=100, shuffle=True)

# Print the size of each dataset.
train_size = len(train_dataloader.dataset)
test_size = len(test_dataloader.dataset)
print(f"train_size: {train_size:d}")
print(f"test_size: {test_size:d}")

# Preview one training batch as a grid with 10 images per row.
images, lables = next(iter(train_dataloader))
img = torchvision.utils.make_grid(images, nrow=10)
# Move the channel axis last (C, H, W -> H, W, C), the layout OpenCV displays.
img = img.numpy().transpose(1, 2, 0)
cv2.imshow('img', img)
cv2.waitKey(0)
# Tear the preview window down once a key was pressed; otherwise it stays
# open (and unresponsive) for the rest of the session.
cv2.destroyAllWindows()

# Define the neural network class.
class NeuralNetwork(nn.Module):
    """Small convolutional classifier for single-channel images.

    Two (convolution -> ReLU -> 2x2 max-pool) stages feed three fully
    connected layers. ``forward`` returns log-probabilities over 10 classes.

    ``fc1`` expects 300 flattened features per sample, which is what a
    1x28x28 input produces: 28 -> 24 (conv 5x5) -> 12 (pool) -> 10 (conv 3x3)
    -> 5 (pool), times 12 channels.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6,
                               kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12,
                               kernel_size=3, stride=1)

        # Fully connected classifier head.
        self.fc1 = nn.Linear(in_features=300, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return per-class log-probabilities for the input batch ``x``."""
        # Stage 1: convolution, ReLU, 2x2 max-pool.
        features = F.max_pool2d(F.relu(self.conv1(x)), 2)
        # Stage 2: same pattern with the second convolution.
        features = F.max_pool2d(F.relu(self.conv2(features)), 2)

        # Keep the batch dimension and flatten everything after it.
        flat = torch.flatten(features, 1)

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)

        # Log-softmax over the class dimension.
        return F.log_softmax(logits, dim=1)

# Hyperparameters.
learning_rate, batch_size, epochs = 1e-3, 100, 400

# Create the model instance.
model = NeuralNetwork()

# Write the model structure to a tensorboardX event file by running one
# random 1x1x28x28 input through it.
dummpy_input = torch.rand((1, 1, 28, 28))
with SummaryWriter('runs/NeuralNetwork/model') as w:
    w.add_graph(model, input_to_model=(dummpy_input,))

# Writer for the training/test loss and accuracy curves. It is used by
# train_loop and test_loop and closed after training.
writer = SummaryWriter('runs/NeuralNetwork/data')

# Define the training loop.
def train_loop(dataloader, model, loss_fn, optimizer,
               epoch=None, summary_writer=None):
    """Train ``model`` for one pass over ``dataloader`` and log each batch loss.

    Args:
        dataloader: Yields ``(inputs, targets)`` batches; must support
            ``len()`` and expose ``.dataset``.
        model: Module to optimise; it is switched to training mode.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a
            scalar tensor.
        optimizer: Optimizer holding the parameters of ``model``.
        epoch: Zero-based epoch index used to compute the logging step.
            When omitted, the module-level loop counter ``t`` is used.
        summary_writer: Object with ``add_scalar(tag, value, step)``. When
            omitted, the module-level ``writer`` is used.
    """
    # Fall back to the script's globals so existing four-argument calls keep
    # working unchanged.
    epoch_index = t if epoch is None else epoch
    log = writer if summary_writer is None else summary_writer

    size = len(dataloader.dataset)
    num_batches = len(dataloader)

    model.train()
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log a plain float at a step that is unique per batch. Logging every
        # batch of an epoch at the same step would stack the points on top of
        # each other in the chart.
        loss_value = loss.item()
        log.add_scalar("train_loss", loss_value,
                       epoch_index * num_batches + batch)

        if batch % 100 == 0:
            current = batch * len(X)
            print(f" loss: {loss_value:>7f} [{current:>5d}/{size:>5d}]")

# Define the evaluation loop.
def test_loop(dataloader, model, loss_fn, epoch=None, summary_writer=None):
    """Evaluate ``model`` on ``dataloader``; log and print loss and accuracy.

    Args:
        dataloader: Yields ``(inputs, targets)`` batches; must support
            ``len()`` and expose ``.dataset``.
        model: Module to evaluate. It runs in eval mode and its previous
            training/eval mode is restored afterwards.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a
            scalar tensor. The reported average assumes it returns the mean
            over the batch (the default for ``nn.CrossEntropyLoss``).
        epoch: Step under which the scalars are logged. When omitted, the
            module-level loop counter ``t`` is used.
        summary_writer: Object with ``add_scalar(tag, value, step)``. When
            omitted, the module-level ``writer`` is used.
    """
    # Fall back to the script's globals so existing three-argument calls keep
    # working unchanged.
    step = t if epoch is None else epoch
    log = writer if summary_writer is None else summary_writer

    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for X, y in dataloader:
                pred = model(X)
                test_loss += loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    # Each accumulated loss is already a per-batch mean, so average over the
    # number of batches. Dividing by the sample count would shrink the value
    # by a factor of the batch size.
    test_loss /= num_batches
    correct /= size

    # Record loss and accuracy.
    log.add_scalar("test_loss", test_loss, step)
    log.add_scalar("test_acc", correct*100, step)

    print(f" test error: \n accuracy: {(100*correct):>0.1f}%, avg loss: {test_loss:>8f}\n")


# The name starts with "test", but this is an evaluation helper, not a pytest
# test. Opt out of collection so pytest never tries to run it with fixtures.
test_loop.__test__ = False

# Build the loss function.
# NOTE: NeuralNetwork.forward already returns log-probabilities.
# CrossEntropyLoss applies log-softmax once more, which leaves
# log-probabilities unchanged, so the loss value is still correct.
loss_fn = nn.CrossEntropyLoss()
# Build the optimizer.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Start training. train_loop and test_loop read the loop counter `t` as a
# global when logging, so the loop variable must keep this name.
try:
    for t in range(epochs):
        print(f" epoch {t + 1}\n---------------")
        train_loop(train_dataloader, model, loss_fn, optimizer)
        test_loop(test_dataloader, model, loss_fn)
    print(" done !")
finally:
    # Close the writer even when training is interrupted or fails, so the
    # scalars logged so far are flushed to disk.
    writer.close()
# Export the trained model parameters.
torch.save(model.state_dict(), "cnn_model.pth")



In [None]:
# Alternatively, build a fresh model and load the parameters saved by the
# training cell:
# model = NeuralNetwork()
# model.load_state_dict(torch.load("cnn_model.pth"))

# Switch to evaluation mode before inference, because a model may execute
# slightly differently in training and inference.
model.eval()

# Show ten figures, each with the first six images of a test batch and the
# predicted digit. The loader is iterated once rather than rebuilt for every
# figure, so each figure gets a different batch even without shuffling.
for num, (x, y) in zip(range(10), test_dataloader):
    # Gradients are not needed for inference.
    with torch.no_grad():
        pred = model(x)
    predicted_digits = pred.argmax(dim=1)

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.imshow(x[i][0], cmap='gray', interpolation='none')
        plt.title("Prediction: {}".format(predicted_digits[i].item()))
        plt.xticks([])
        plt.yticks([])
    # Lay the figure out once, after every subplot and title exists.
    plt.tight_layout()
    plt.show()
