In [4]:
# 优秀的训练函数
# 参数初始化、动态修改学习率、优化器选择
# https://www.zhihu.com/question/523869554/answer/2560312612
import torch
from torch import nn, optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchinfo import summary
import torchvision
from torchvision import datasets, transforms
from matplotlib import pyplot as plt
from tqdm import tqdm
import numpy as np


def get_dataloader(batch_size, root='../data', download=False):
    """Build the CIFAR-10 training and test DataLoaders.

    Training images get a random resized crop to 224 and a random horizontal
    flip; test images are resized to 256 and center-cropped to 224. Both
    pipelines then convert to tensors and apply the same per-channel
    normalization.

    Args:
        batch_size: Number of samples per batch, used for both loaders.
        root: Directory that holds the CIFAR-10 files.
        download: Passed through to ``datasets.CIFAR10``; when False the
            dataset must already be present under ``root``.

    Returns:
        A tuple ``(class_to_idx, class_names, train_loader, test_loader)``
        where the first two items are taken from the training dataset.
    """
    # Same mean/std for both pipelines, so build the transform once.
    # NOTE(review): these look like the usual ImageNet channel statistics,
    # not CIFAR-10's own - confirm this is intended.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    transform = {
        'train': transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     normalize, ]),
        'valid': transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     normalize, ])
    }
    train_ds = datasets.CIFAR10(root=root, train=True, download=download, transform=transform['train'])
    test_ds = datasets.CIFAR10(root=root, train=False, download=download, transform=transform['valid'])
    class_to_idx = train_ds.class_to_idx
    class_names = train_ds.classes
    print(f'分类标签为：{class_to_idx}')
    print(f'训练集数据量：{len(train_ds)}')
    print(f'测试集数据量: {len(test_ds)}')
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    # Evaluation data is not shuffled: a fixed order keeps per-batch results
    # reproducible between runs and has no effect on aggregate metrics.
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
    return class_to_idx, class_names, train_loader, test_loader


# Build the CIFAR-10 loaders once (batch size 64) and keep the results as
# notebook-level names: the class-name -> index mapping, the list of class
# names, the training loader and the test loader.
class2int, classes, train_dl, test_dl = get_dataloader(batch_size=64)

分类标签为：{'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4, 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
训练集数据量：50000
测试集数据量: 10000


In [None]:
def draw_example(dataloader):
    """Show 6 images from ``dataloader`` (intended purpose).

    The plotting itself is not implemented yet: the body only picks the
    compute device (CUDA when available, otherwise CPU) and returns None
    without reading ``dataloader``.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

In [None]:
def train(net,
          device,
          train_loader,
          valid_loader,
          batch_size,
          max_epoch,
          lr,
          lr_min,
          criterion,
          optimizer_type='sgd',
          scheduler_type='cosine',
          init=True):
    def init_xavier(m):
        if type(m) == nn.Linear or type(m) == nn.Conv2d:
            nn.init.xavier_uniform_(m.weight)

    if init:
        net.apply(init_xavier)
    print(f'Training on device: {device}')
    net.to(device)

    if optimizer_type == 'sgd':
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)