<a href="https://colab.research.google.com/github/zhangfuyao/Google-colab/blob/classifier/aiMNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchvision.datasets
from torch import nn
import torchvision.transforms as transforms
from tqdm import tqdm
from torch.utils.data import DataLoader

In [None]:
# ---- Hyperparameters ----
EPOCH: int = 10        # number of full passes over the training set
BATCH_SIZE: int = 64   # samples per mini-batch
LR: float = 0.01       # learning rate for SGD
momentum: float = 0.5  # SGD momentum factor

In [None]:
# Convert images to float tensors, then standardise with mean 0.1307 / std 0.3081
# (presumably the usual MNIST training-set statistics).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Both splits share the same root directory, download flag and preprocessing.
_mnist_common = dict(root="./mnist", download=True, transform=transform)
train_data = torchvision.datasets.MNIST(train=True, **_mnist_common)
test_data = torchvision.datasets.MNIST(train=False, **_mnist_common)

# Training batches are reshuffled every epoch; test order stays fixed.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
class CNN(nn.Module):
    """Small convolutional classifier producing 10 class logits.

    Shapes, for a (N, 1, 28, 28) input (the hard-coded 320 = 20 * 4 * 4
    in the head only matches 28x28 single-channel images):

        conv1: Conv2d(1 -> 10, k=5) -> ReLU -> MaxPool(2)   => (N, 10, 12, 12)
        conv2: Conv2d(10 -> 20, k=5) -> ReLU -> MaxPool(2)  => (N, 20, 4, 4)
        flatten                                             => (N, 320)
        out:   Linear(320 -> 50) -> ReLU -> Linear(50 -> 10) => (N, 10)

    The output is raw logits (no softmax).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(10, 20, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Classifier head. Without an activation between these two Linear
        # layers they collapse into one affine map, so forward() applies a ReLU
        # between them. It is applied there rather than inserted into this
        # Sequential so the state_dict keys stay "out.0.*" / "out.1.*" and
        # existing checkpoints still load.
        self.out = nn.Sequential(
            nn.Linear(320, 50),
            nn.Linear(50, 10),
        )

    def forward(self, x):
        """Map a batch of images to class logits of shape (N, 10)."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = torch.flatten(x, start_dim=1)  # (N, 20, 4, 4) -> (N, 320)
        hidden = torch.relu(self.out[0](x))
        return self.out[1](hidden)

In [None]:
# Model, loss and optimiser shared by the training loop.
model = CNN()
criterion = nn.CrossEntropyLoss()  # takes raw logits and class-index targets
optimizer = torch.optim.SGD(params=model.parameters(), lr=LR, momentum=momentum)

In [None]:
def train(epoch, log_interval=300):
    """Run one pass over ``train_loader``, updating ``model`` in place.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and
    ``train_loader``. Every ``log_interval`` batches it prints the mean loss
    and the accuracy over those batches, then resets the running statistics.
    Logging only every few hundred batches keeps output and overhead small.

    Args:
        epoch: Zero-based epoch index; printed 1-based in the log lines.
        log_interval: Number of batches between log lines (default 300).

    Raises:
        ValueError: If ``log_interval`` is not positive.
    """
    if log_interval <= 0:
        raise ValueError("log_interval must be a positive integer")

    # Has no effect on the layers CNN currently uses, but keeps things correct
    # if dropout/batch-norm layers are added later.
    model.train()

    running_loss = 0.0
    running_total = 0
    running_correct = 0

    # The context manager closes the bar even on error, so no half-drawn bars
    # are left behind in the notebook output.
    with tqdm(total=len(train_loader), leave=False) as progress_bar:
        for batch_idx, (inputs, target) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()
            progress_bar.update()

            running_loss += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            running_total += inputs.shape[0]
            running_correct += (predicted == target).sum().item()

            if (batch_idx + 1) % log_interval == 0:
                # tqdm.write prints above the live bar instead of splicing
                # the message into the middle of it.
                tqdm.write('[%d, %5d]: loss: %.3f , acc: %.2f %%'
                           % (epoch + 1, batch_idx + 1,
                              running_loss / log_interval,
                              100 * running_correct / running_total))
                running_loss = 0.0
                running_total = 0
                running_correct = 0

In [None]:
if __name__ == '__main__':
    # EPOCH is the number of passes over the training set, and train() performs
    # exactly one pass. Call it once per epoch index rather than passing EPOCH
    # as the index, which would run a single pass labelled "epoch 11".
    for epoch in range(EPOCH):
        train(epoch)


  0%|          | 0/938 [00:00<?, ?it/s][A
  0%|          | 4/938 [00:00<00:29, 32.00it/s][A
  1%|          | 8/938 [00:00<00:29, 31.85it/s][A
  1%|▏         | 12/938 [00:00<00:29, 31.40it/s][A
  2%|▏         | 16/938 [00:00<00:29, 31.71it/s][A
  2%|▏         | 20/938 [00:00<00:29, 31.35it/s][A
  3%|▎         | 24/938 [00:00<00:29, 30.91it/s][A
  3%|▎         | 28/938 [00:00<00:30, 29.79it/s][A
  3%|▎         | 32/938 [00:01<00:30, 30.13it/s][A
  4%|▍         | 36/938 [00:01<00:29, 30.36it/s][A
  4%|▍         | 40/938 [00:01<00:29, 30.56it/s][A
  5%|▍         | 44/938 [00:01<00:29, 30.22it/s][A
  5%|▌         | 48/938 [00:01<00:30, 29.22it/s][A
  6%|▌         | 52/938 [00:01<00:29, 29.96it/s][A
  6%|▌         | 56/938 [00:01<00:28, 30.46it/s][A
  6%|▋         | 60/938 [00:01<00:29, 29.94it/s][A
  7%|▋         | 64/938 [00:02<00:29, 30.11it/s][A
  7%|▋         | 68/938 [00:02<00:29, 29.73it/s][A
  8%|▊         | 71/938 [00:02<00:29, 29.57it/s][A
  8%|▊         | 74/93

[11,   300]: loss: 0.131 , acc: 96.21 %



 33%|███▎      | 308/938 [00:10<00:22, 28.59it/s][A
 33%|███▎      | 312/938 [00:10<00:21, 29.50it/s][A
 34%|███▎      | 315/938 [00:10<00:21, 28.97it/s][A
 34%|███▍      | 318/938 [00:10<00:21, 29.11it/s][A
 34%|███▍      | 321/938 [00:11<00:21, 29.05it/s][A
 35%|███▍      | 324/938 [00:11<00:21, 28.08it/s][A
 35%|███▍      | 327/938 [00:11<00:22, 27.58it/s][A
 35%|███▌      | 330/938 [00:11<00:21, 28.15it/s][A
 36%|███▌      | 333/938 [00:11<00:21, 28.38it/s][A
 36%|███▌      | 336/938 [00:11<00:21, 28.33it/s][A
 36%|███▌      | 340/938 [00:11<00:20, 29.00it/s][A
 37%|███▋      | 343/938 [00:11<00:20, 29.03it/s][A
 37%|███▋      | 347/938 [00:11<00:19, 29.75it/s][A
 37%|███▋      | 350/938 [00:12<00:19, 29.62it/s][A
 38%|███▊      | 353/938 [00:12<00:20, 28.71it/s][A
 38%|███▊      | 357/938 [00:12<00:19, 29.60it/s][A
 38%|███▊      | 360/938 [00:12<00:19, 29.25it/s][A
 39%|███▊      | 363/938 [00:12<00:19, 28.94it/s][A
 39%|███▉      | 366/938 [00:12<00:20, 28.47i

[11,   600]: loss: 0.111 , acc: 96.61 %



 65%|██████▍   | 607/938 [00:20<00:11, 28.92it/s][A
 65%|██████▌   | 610/938 [00:20<00:11, 29.05it/s][A
 65%|██████▌   | 613/938 [00:20<00:11, 29.17it/s][A
 66%|██████▌   | 616/938 [00:21<00:11, 28.67it/s][A
 66%|██████▌   | 619/938 [00:21<00:11, 28.83it/s][A
 66%|██████▋   | 623/938 [00:21<00:10, 29.72it/s][A
 67%|██████▋   | 626/938 [00:21<00:10, 29.60it/s][A
 67%|██████▋   | 629/938 [00:21<00:10, 29.28it/s][A
 67%|██████▋   | 632/938 [00:21<00:10, 29.17it/s][A
 68%|██████▊   | 635/938 [00:21<00:10, 29.39it/s][A
 68%|██████▊   | 638/938 [00:21<00:10, 29.49it/s][A
 68%|██████▊   | 642/938 [00:21<00:09, 30.36it/s][A
 69%|██████▉   | 646/938 [00:22<00:09, 30.44it/s][A
 69%|██████▉   | 650/938 [00:22<00:09, 30.87it/s][A
 70%|██████▉   | 654/938 [00:22<00:09, 31.15it/s][A
 70%|███████   | 658/938 [00:22<00:09, 30.04it/s][A
 71%|███████   | 662/938 [00:22<00:09, 30.21it/s][A
 71%|███████   | 666/938 [00:22<00:08, 30.30it/s][A
 71%|███████▏  | 670/938 [00:22<00:08, 29.83i

[11,   900]: loss: 0.097 , acc: 96.91 %



 97%|█████████▋| 907/938 [00:31<00:01, 28.66it/s][A
 97%|█████████▋| 910/938 [00:31<00:00, 29.00it/s][A
 97%|█████████▋| 914/938 [00:31<00:00, 29.90it/s][A
 98%|█████████▊| 917/938 [00:31<00:00, 28.33it/s][A
 98%|█████████▊| 921/938 [00:31<00:00, 29.14it/s][A
 99%|█████████▊| 925/938 [00:31<00:00, 29.81it/s][A
 99%|█████████▉| 928/938 [00:31<00:00, 29.71it/s][A
 99%|█████████▉| 931/938 [00:32<00:00, 29.11it/s][A
100%|█████████▉| 934/938 [00:32<00:00, 28.43it/s][A
100%|█████████▉| 937/938 [00:32<00:00, 28.17it/s][A
                                                 [A