In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from torch import optim
import os
import csv
from PIL import Image
import warnings
warnings.simplefilter('ignore')
from torchvision import datasets

#载入数据
trans = transforms.Compose((transforms.Resize((32,32)),transforms.ToTensor()))
train_set = datasets.MNIST('./num',train=True,transform=trans, download=true)
#mnist中的test_set一共有1万张照片，这里我们把前5000张用作validation_set,后5000张用作test_set
val_set = list(datasets.MNIST('./num',train=False,transform=trans, download=true))[:5000]
test_set = list(datasets.MNIST('./num',train=False,transform=trans, download=true))[5000:]

train_loader = DataLoader(train_set,batch_size=150,shuffle=True)
val_loader = DataLoader(val_set,batch_size=50,shuffle=True)
test_loader = DataLoader(test_set,batch_size=50,shuffle=True)


#构建resblock
class resblock(nn.Module):
    def __init__(self,ch_in,ch_out,stride=1):
        super(resblock,self).__init__()
        self.conv_1 = nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)
        self.bn_1 = nn.BatchNorm2d(ch_out)
        self.conv_2 = nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1)
        self.bn_2 = nn.BatchNorm2d(ch_out)
        self.ch_trans = nn.Sequential()
        if ch_in != ch_out:
            self.ch_trans = nn.Sequential(nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),nn.BatchNorm2d(ch_out))
        #ch_trans表示通道数转变。因为要做short_cut,所以x_pro和x_ch的size应该完全一致
        
    def  forward(self,x):
        x_pro = F.relu(self.bn_1(self.conv_1(x)))
        x_pro = self.bn_2(self.conv_2(x_pro))
        
        #short_cut:
        x_ch = self.ch_trans(x)
        out = x_pro + x_ch
        out = F.relu(out)
        return out 
    
    
#搭建resnet
class Resnet18(nn.Module):
    def __init__(self,num_class):
        super(Resnet18,self).__init__()
        self.conv_1 = nn.Sequential(
        nn.Conv2d(1,16,kernel_size=3,stride=3,padding=0),
        nn.BatchNorm2d(16))
        self.block1 = resblock(16,32,1) 
        self.block2 = resblock(32,64,1) 
        self.block3 = resblock(64,128,2)
        self.block4 = resblock(128,256,2)
        self.outlayer = nn.Linear(256*3*3,num_class)#这个256*3*3是根据forward中x经过4个resblock之后来决定的
        
    def forward(self,x):
        x = F.relu(self.conv_1(x))
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = x.reshape(x.size(0),-1) #进行打平操作
        result = self.outlayer(x)
        return result       

RuntimeError: Dataset not found. You can use download=True to download it

In [None]:
device = torch.device('cuda')
model = Resnet18(10).to(device) #模型初始化，10代表一共有10种类别
print('模型需要训练的参数共有{}个'.format(sum(map(lambda p:p.numel(),model.parameters()))))
loss_fn = nn.CrossEntropyLoss() #选择loss_function
optimizer = optim.Adam(model.parameters(),lr=1e-3) #选择优化方式

In [None]:
# evaluate用于检测模型的预测效果，validation_set和test_set是同样的evaluate方法
def evaluate(model,loader):
    correct_num = 0
    total_num = len(loader.dataset)
    for img,label in loader: #lodaer中包含了很多batch，每个batch有32张图片
        img,label = img.to(device),label.to(device)
        with torch.no_grad():
            logits = model(img)
            pre_label = logits.argmax(dim=1)
        correct_num += torch.eq(pre_label,label).sum().float().item()
    
    return correct_num/total_num 



best_epoch,best_acc = 0,0
for epoch in range(10): #时间关系，我们只训练10个epoch
    for batch_num,(img,label) in enumerate(train_loader):
        #img.size [b,3,224,224]  label.size [b]
        img,label = img.to(device),label.to(device)
        logits = model(img)
        loss = loss_fn(logits,label)
        if (batch_num+1)%100 == 0:
            print('这是第{}次迭代的第{}个batch,loss是{}'.format(epoch+1,batch_num+1,loss.item()))
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    if epoch%2==0: #这里设置的是每训练两次epoch就进行一次validation
        val_acc = evaluate(model,val_loader)
        #如果val_acc比之前的好，那么就把该epoch保存下来，并把此时模型的参数保存到指定txt文件里
        if val_acc>best_acc:
            print('验证集上的准确率是：{}'.format(val_acc))
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(),'mnist_resnet_ckp.txt')
    

print('best_acc:{},best_epoch:{}'.format(best_acc,best_epoch))
model.load_state_dict(torch.load('mnist_resnet_ckp.txt'))
print('模型训练完毕，已将参数设置成训练过程中的最优值，现在开始测试test_set')

test_acc = evaluate(model,test_loader)
print('测试集上的准确率是：{}'.format(test_acc))