In [42]:
import torch
import torchvision
from torch.fx.experimental.partitioner_utils import Device
from torch.utils.data import DataLoader
from torch import nn
from torch.nn import Conv2d,MaxPool2d, Flatten, Linear
from torch.utils.tensorboard import SummaryWriter
import time

In [21]:
#准备数据集
train_data=torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data=torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)

In [22]:
#length长度
print("训练集长度：{}".format(len(train_data)))
print("测试集长度：{}".format(len(test_data)))

训练集长度：50000
测试集长度：10000


In [23]:
#利用Dataloader 来加载数据
train_loader=DataLoader(dataset=train_data,batch_size=64,shuffle=True,num_workers=0,drop_last=False)
test_loader=DataLoader(dataset=test_data,batch_size=64,shuffle=False,num_workers=0,drop_last=False)

In [33]:
#创建模型
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.model=nn.Sequential(
            Conv2d(in_channels=3,out_channels=32,kernel_size=5,padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(in_channels=32,out_channels=32,kernel_size=5,padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(in_channels=32,out_channels=64,kernel_size=5,padding=2),
            MaxPool2d(kernel_size=2),
            Flatten(),
            Linear(1024,64),
            Linear(64,10)
        )
    def forward(self, x):
        x=self.model(x)
        return x


# 使用GPU进行训练时，需要将模型，损失函数，优化器，数据加载器都移动到GPU上

In [45]:
#创建模型并进行训练、测试和保存
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
model=Model()
model.to(device)
loss_fn=nn.CrossEntropyLoss()
loss_fn.to(device)
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)

writer=SummaryWriter("./tensorboard/CIFAR10_train")
for epoch in range(20):
    train_loss=0.0
    start_time=time.time()
    print("Epoch:{}".format(epoch+1))
    for data in train_loader:
        inputs,targets=data
        inputs=inputs.to(device)
        targets=targets.to(device)

        outputs=model(inputs)
        loss=loss_fn(outputs,targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss+=loss.item()
    end_time=time.time()
    writer.add_scalar("train_loss",train_loss,epoch+1)
    print("Epoch:{},Train Loss:{},Time:{}".format(epoch+1,train_loss,end_time-start_time))

    #对模型进行测试
    model.eval()
    test_loss=0.0
    test_accuracy=0.0
    with torch.no_grad():
        for data in test_loader:
            inputs,targets=data
            inputs=inputs.to(device)
            targets=targets.to(device)
            outputs=model(inputs)
            loss=loss_fn(outputs,targets)
            test_loss+=loss.item()
            accuracy=(outputs.argmax(1)==targets).sum()
            test_accuracy+=accuracy
    writer.add_scalar("test_loss",test_loss,epoch+1)
    print("Epoch:{},Test Loss:{},Test Accuracy:{}".format(epoch+1,test_loss,test_accuracy))

    #模型保存
    torch.save(model,"./model/cifar/cifar10_model_{}.pth".format(epoch+1))

writer.close()

Epoch:1
Epoch:1,Train Loss:1698.6954525709152,Time:10.128042936325073
Epoch:1,Test Loss:305.4164527654648,Test Accuracy:3064.0
Epoch:2
Epoch:2,Train Loss:1443.4094796180725,Time:9.8090181350708
Epoch:2,Test Loss:273.7366615533829,Test Accuracy:3802.0
Epoch:3
Epoch:3,Train Loss:1302.0689737796783,Time:10.120954751968384
Epoch:3,Test Loss:255.21330332756042,Test Accuracy:4133.0
Epoch:4
Epoch:4,Train Loss:1213.2358391284943,Time:10.157371044158936
Epoch:4,Test Loss:255.04289388656616,Test Accuracy:4390.0
Epoch:5
Epoch:5,Train Loss:1150.2297013998032,Time:9.996843099594116
Epoch:5,Test Loss:258.3719642162323,Test Accuracy:4284.0
Epoch:6
Epoch:6,Train Loss:1096.354645729065,Time:9.85084843635559
Epoch:6,Test Loss:223.201984167099,Test Accuracy:4886.0
Epoch:7
Epoch:7,Train Loss:1044.6497728824615,Time:9.843160152435303
Epoch:7,Test Loss:216.08834838867188,Test Accuracy:5150.0
Epoch:8
Epoch:8,Train Loss:997.1058881282806,Time:9.842959880828857
Epoch:8,Test Loss:202.46305298805237,Test Accurac