@Author : kkutysllb

@E-mail : libing1@sn.chinamobile.com，31468130@qq.com

@Date   : 2024-11-26 10:15

@Desc   : AleNet图片分类

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import CIFAR10
from torchvision import transforms
import sys
import os
sys.path.append('../')
from kk_libraries.kk_functions import get_device, kk_ImageClassifierTrainer
from kk_libraries.kk_models import kk_init_weights_relu
from kk_libraries.kk_dataprocess import kk_load_data, kk_predict_images_labels
from kk_libraries.kk_constants import text_labels_cifar10, mean, std


In [2]:
# 定义模型
class AlexNet(nn.Module):
    def __init__(self, in_channels=3, num_classes=10):
        super(AlexNet, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 96, kernel_size=11, stride=4, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.conv2 = nn.Sequential(
             nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),
             nn.ReLU(),
             nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        self.fc = nn.Sequential(
            nn.Linear(256*5*5, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, num_classes)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = x.view(x.size(0), 256*5*5)
        x = self.fc(x)
        return x


In [11]:
# 定义数据预处理
def kk_data_transform():
    return {
        'train': transforms.Compose([transforms.RandomResizedCrop(224),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean, std)]),
        'valid': transforms.Compose([transforms.Resize(256), 
                                     transforms.CenterCrop(224), 
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean, std)])
    }

In [13]:
# 数据加载
data_path = os.path.join('../', "data/CIFAR10")
train_loader, valid_loader, test_loader = kk_load_data(data_path, ratio=0.15, batch_size=256, DataSets=CIFAR10, transform=kk_data_transform())

Files already downloaded and verified
Files already downloaded and verified
训练集大小: 42500, 验证集大小: 7500, 测试集大小: 10000


In [20]:
class Config():
    """配置类"""
    def __init__(self):
        self.num_epochs = 100
        self.patience = 20
        self.device = get_device()
        self.save_path = os.path.join('../', "models", "AlexNet")
        self.logs_path = os.path.join('../', "logs", "AlexNet")
        self.plot_titles = 'AlexNet'
        self.class_list = text_labels_cifar10

In [23]:
# 定义超参数
epochs = 100
lr = 0.01
device = get_device()
# 定义模型
model = AlexNet(in_channels=3, num_classes=10).to(device)
model.apply(kk_init_weights_relu)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.99)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.3)


In [22]:
# 训练
config = Config()
trainer = kk_ImageClassifierTrainer(config, model, criterion, optimizer, scheduler)
trainer.train_iter(train_loader, valid_loader)


Epoch: 【1/100】
Iter 0      训练损失: 2.3327, 训练精度: 8.203%, 验证精度: 10.160%, 模型优化: * 训练设备: mps, 学习率: 0.010000000
Iter 100    训练损失: 2.0506, 训练精度: 23.403%, 验证精度: 29.587%, 模型优化: * 训练设备: mps, 学习率: 0.010000000
Epoch: 【2/100】
Iter 200    训练损失: 1.9951, 训练精度: 25.928%, 验证精度: 26.293%, 模型优化:  训练设备: mps, 学习率: 0.010000000
Iter 300    训练损失: 1.9555, 训练精度: 27.285%, 验证精度: 29.600%, 模型优化: * 训练设备: mps, 学习率: 0.010000000
Epoch: 【3/100】
Iter 400    训练损失: 1.9366, 训练精度: 27.916%, 验证精度: 30.427%, 模型优化: * 训练设备: mps, 学习率: 0.010000000
Iter 500    训练损失: 1.9124, 训练精度: 28.955%, 验证精度: 34.173%, 模型优化: * 训练设备: mps, 学习率: 0.010000000
Epoch: 【4/100】
Iter 600    训练损失: 1.9143, 训练精度: 28.811%, 验证精度: 30.440%, 模型优化:  训练设备: mps, 学习率: 0.010000000
Epoch: 【5/100】
Iter 700    训练损失: 1.9096, 训练精度: 29.070%, 验证精度: 25.120%, 模型优化:  训练设备: mps, 学习率: 0.010000000
Iter 800    训练损失: 1.9040, 训练精度: 29.262%, 验证精度: 33.520%, 模型优化:  训练设备: mps, 学习率: 0.010000000
Epoch: 【6/100】
Iter 900    训练损失: 1.8982, 训练精度: 29.467%, 验证精度: 31.400%, 模型优化:  训练设备: mps, 学习率: 0.010000