torch.utils.data.Dataset 介绍与实战:

https://blog.csdn.net/weixin_44211968/article/details/123744513


In [None]:
import os
import sys
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
from PIL import Image

import timm

# from model import resnet34
# 缺少一个 vgg 模型
# from model import vgg

from torchvision import models

def _build_data_transforms():
    """Return the ``{"train": ..., "val": ...}`` torchvision transform pipelines.

    Both pipelines normalize with the standard ImageNet channel mean/std, which
    matches the ImageNet-pretrained backbone created in ``main``.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    return {
        # Random crop + horizontal flip as training-time augmentation.
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize(imagenet_mean, imagenet_std)]),
        # Deterministic resize + center crop for evaluation.
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize(imagenet_mean, imagenet_std)]),
    }


def _num_workers(batch_size):
    """Return the DataLoader worker count: min(CPU count, batch_size, 8), or 0 if batch_size <= 1."""
    # os.cpu_count() returns None when the count cannot be determined; min() would then fail.
    cpu_count = os.cpu_count() or 1
    return min(cpu_count, batch_size if batch_size > 1 else 0, 8)


def _train_one_epoch(net, loader, loss_function, optimizer, device, epoch, epochs):
    """Run one training pass over ``loader`` and return the sum of the per-batch losses."""
    net.train()
    running_loss = 0.0
    train_bar = tqdm(loader, file=sys.stdout)
    for images, labels in train_bar:
        optimizer.zero_grad()
        logits = net(images.to(device))
        loss = loss_function(logits, labels.to(device))
        loss.backward()
        optimizer.step()

        # Convert to a Python float once; used for both the running sum and the progress bar.
        batch_loss = loss.item()
        running_loss += batch_loss
        train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, batch_loss)
    return running_loss


def _evaluate(net, loader, device, epoch, epochs):
    """Return the number of correctly classified samples in ``loader`` (no gradients, eval mode)."""
    net.eval()
    correct = 0
    with torch.no_grad():
        val_bar = tqdm(loader, file=sys.stdout)
        for val_images, val_labels in val_bar:
            outputs = net(val_images.to(device))
            # The index of the largest logit in each row is the predicted class index.
            predict_y = outputs.argmax(dim=1)
            correct += torch.eq(predict_y, val_labels.to(device)).sum().item()
            val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, epochs)
    return correct


def main():
    """Fine-tune an ImageNet-pretrained EfficientNet-B0 (timm) on an ImageFolder dataset.

    The dataset is read from ``../分类数据/eggDataset/train`` and ``.../val``,
    relative to the current working directory. Each split has one sub-directory
    per class.

    Side effects:
      * writes ``class_indices.json`` (label index -> class name) into the CWD;
      * saves the state dict from the epoch with the best validation accuracy
        to ``./EfficientNet.pth``.

    Raises:
        FileNotFoundError: if the dataset directory does not exist.
        ValueError: if the train and val splits do not have the same class folders.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Locate the dataset relative to the current working directory.
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../分类数据"))
    image_path = os.path.join(data_root, "eggDataset")
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if not os.path.exists(image_path):
        raise FileNotFoundError("{} path does not exist.".format(image_path))

    data_transform = _build_data_transforms()

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    class_list = train_dataset.class_to_idx
    print('class_list', class_list)

    # Persist the inverse mapping (index -> class name) so predictions can be decoded later.
    cla_dict = {idx: name for name, idx in class_list.items()}
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 10
    nw = _num_workers(batch_size)
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    # Labels come from each split's own class folders. If the splits differ, the same
    # integer label would mean different classes and the accuracy would be meaningless.
    if validate_dataset.class_to_idx != class_list:
        raise ValueError("train/val class folders differ: {} vs {}".format(
            class_list, validate_dataset.class_to_idx))

    train_num = len(train_dataset)
    test_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num, test_num))

    # Size the classifier head to this dataset instead of keeping the 1000-class ImageNet head.
    # timm loads the pretrained backbone and initializes a fresh classifier for num_classes.
    net = timm.create_model('efficientnet_b0', pretrained=True, num_classes=len(class_list))
    model_name = "EfficientNet"
    net.to(device)

    loss_function = nn.CrossEntropyLoss()

    # requires_grad is a per-parameter flag; net.train()/net.eval() do not change it.
    # nn.Parameter defaults to requires_grad=True, so this filter only drops parameters
    # that were explicitly frozen beforehand.
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 5
    best_acc = 0.0
    save_path = './{}.pth'.format(model_name)
    train_steps = len(train_loader)
    for epoch in range(epochs):
        running_loss = _train_one_epoch(net, train_loader, loss_function, optimizer,
                                        device, epoch, epochs)
        correct = _evaluate(net, validate_loader, device, epoch, epochs)

        val_accurate = correct / test_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        # Keep only the checkpoint with the best validation accuracy seen so far.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')

# Run training only when this file is executed directly, not when it is imported.
if __name__ == '__main__':
    main()
                        