<a href="https://colab.research.google.com/github/zhangfuyao/Google-colab/blob/classifier/class.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torchvision.datasets
from torch import nn
import torchvision.transforms as transforms
from tqdm import tqdm
from torch.utils.data import DataLoader

In [2]:
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [3]:
# 定义超参数
batch_size = 36
lr = 1e-3
epochs=10

In [10]:
train_data=torchvision.datasets.CIFAR10(root="./data",train=True,download=True,transform=transform)

train_dataloader=DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)

valid_data=torchvision.datasets.CIFAR10(root="./data",train=False,download=True,transform=transform)

valid_dataloader=DataLoader(dataset=valid_data,batch_size=batch_size,shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [11]:
test_data_iter = iter(valid_dataloader)  # 将testloader转换为迭代器
test_img, test_label = test_data_iter.next()  # 通过next（）获得一批数据

classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")

导入数据集

In [26]:
class LetNet(nn.Module):
  def __init__(self):
    super(LetNet,self).__init__()
    self.conv1=nn.Sequential(
        nn.Conv2d(3,16,5),
        nn.ReLU(),
        nn.MaxPool2d(2,2)
    )
    self.conv2=nn.Sequential(
        nn.Conv2d(16,32,5),
        nn.ReLU(),
        nn.MaxPool2d(2,2),
    )
    self.out=nn.Sequential(
        torch.nn.Linear(32*5*5,120),
        torch.nn.Linear(120,84),
        torch.nn.Linear(84,10)
    )
  def forward(self,x):
    x=self.conv1(x)
    x=self.conv2(x)
    x=x.view(-1,32*5*5)
    x=self.out(x)
    return x

In [27]:
# 开始实例化模型并定义损失函数和优化器
model=LetNet()
criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=lr)

In [28]:
import matplotlib as plt
import numpy as np
def imshow(img):
    img = img / 2 + 0.5  # unnormalize反标准化过程input = output*0.5 + 0.5
    npimg = img.numpy()  # 转换为numpy
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # Pytorch内Tensor顺序[batch, channel, height, width],由于输入没有batch，故channel对于0，height对应1，width对应2
    # 此处要还原为载入图像时基础的shape，所以应把顺序变为[height, width, channel]， 所以需要np.transpose(npimg, (1, 2, 0))
    plt.show()

In [42]:
for epoch in range(1):
  running_loss = 0.0  # 这整个epoch的loss清零
  running_total = 0
  running_correct = 0
  for step, data in enumerate(train_dataloader, start=0):
    inputs,labels=data
    optimizer.zero_grad() # 梯度清零
    
    # 前向传播
    outputs=model(inputs)
    loss=criterion(outputs,labels)
    loss.backward()
    optimizer.step()
    running_loss += loss.item()
    predicted = torch.argmax(outputs.data, dim=1)
    running_total += inputs.shape[0]
    running_correct += (predicted == labels).sum().item()
    if step % 500==499:  # 不想要每一次都出loss，
        print('[%d, %5d]: loss: %.3f , acc: %.2f %%'
                  % (epoch + 1, step + 1, running_loss / 500, 100 * running_correct / running_total))

        running_loss = 0.0  # 这小批的loss清零
        running_total = 0
        running_correct = 0  # 这小批的acc清零
print("Training finished")
save_path = './path.pth'
torch.save(model.state_dict(), save_path)

[1,   500]: loss: 0.702 , acc: 75.27 %
[1,  1000]: loss: 0.715 , acc: 75.07 %
Training finished


In [46]:
import torch
import torchvision.transforms as transforms
from PIL import Image

transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]  # 标准化 output = (input- 0.5)/0.5
)

classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")

net = LetNet()

net.load_state_dict(torch.load('path.pth'))  # 载入权重文件

im = Image.open('/content/test.png').convert("RGB")
im = transform(im)  # [C, H, W] 转成Pytorch的Tensor格式
im = torch.unsqueeze(im, dim=0)  # [N, C, H, W] 对数据增加一个新维度

with torch.no_grad():
    outputs = net(im)
    predict = torch.max(outputs, dim=1)[1].data.numpy()
print(classes[int(predict)])

plane
