In [None]:

import sys
import torch
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from model import resnet34

import os

import albumentations as A
from albumentations.pytorch import ToTensorV2

import numpy as np

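# Custom dataset class: it must implement two methods, __len__ and __getitem__.
# 1. __len__: makes len(dataset) return the total number of samples in the dataset.
# 2. __getitem__: fetches a sample by index, so that dataset[i] returns the i-th sample.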
class CustomDataset(Dataset):
    """Image-classification dataset read from a ``<data_path>/<class_name>/<file>`` tree.

    Every sub-directory of ``data_path`` is treated as one class. Class names are
    sorted before integer labels are assigned, so two instances built from
    directory trees with the same class folders (e.g. a train split and a val
    split) always agree on which integer stands for which class.

    Args:
        data_path: Root directory containing one sub-directory per class.
        transform: Optional callable invoked as ``transform(image=ndarray)`` that
            returns a mapping with the result stored under the ``'image'`` key.

    Attributes:
        classes: Sorted list of class (sub-directory) names.
        class_to_idx: Mapping from class name to its integer label.
        data: List of ``(image_path, label)`` tuples, one per sample.
    """

    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform

        # os.listdir returns entries in arbitrary order and may include plain
        # files. Keep directories only and sort them so that labels are
        # contiguous (0..n-1) and reproducible across runs and across splits.
        self.classes = sorted(
            entry for entry in os.listdir(data_path)
            if os.path.isdir(os.path.join(data_path, entry))
        )
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

        self.data = []  # (image path, integer label) pairs
        for class_name in self.classes:
            class_path = os.path.join(data_path, class_name)
            class_label = self.class_to_idx[class_name]
            # Sorted for a deterministic sample order.
            for img_filename in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_filename)
                # Nested directories are not samples; skip them instead of
                # failing later when the image is opened.
                if os.path.isfile(img_path):
                    self.data.append((img_path, class_label))

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the sample at ``idx`` and return it as ``(image, label)``.

        The image is decoded as RGB. When a transform was supplied, it receives
        the image as a numpy array and the value it stores under ``'image'`` is
        returned; otherwise the RGB ``PIL.Image`` itself is returned.
        """
        img_path, label = self.data[idx]

        # Image.open is lazy and keeps the file open; convert() forces the
        # pixel data to be read, and the context manager then closes the
        # underlying file handle.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        # Compare against None explicitly: a pipeline object that defines
        # __len__ would be falsy when empty and silently skipped.
        if self.transform is not None:
            transformed = self.transform(image=np.array(img))
            img = transformed['image']

        return img, label


## 使用albumentations中的数据增强方法

In [None]:
# Dataset location: the flower data lives under "data_set/flower_data",
# two directory levels above the current working directory.
data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
image_path = os.path.join(data_root, "data_set", "flower_data")

# One albumentations pipeline per split.
#   train: resize shortest side to 160, random geometric/colour jitter,
#          random 128x128 crop, normalise, convert to a tensor.
#   val:   resize shortest side to 160, deterministic 128x128 centre crop,
#          normalise, convert to a tensor (no random augmentation).
# The mean/std values are presumably the usual ImageNet channel statistics.
data_transform = dict(
    train=A.Compose([
        A.SmallestMaxSize(max_size=160),
        A.ShiftScaleRotate(
            shift_limit=0.05,
            scale_limit=0.05,
            rotate_limit=15,
            p=0.5,
        ),
        A.RandomCrop(height=128, width=128),
        A.RGBShift(
            r_shift_limit=15,
            g_shift_limit=15,
            b_shift_limit=15,
            p=0.5,
        ),
        A.RandomBrightnessContrast(p=0.5),
        A.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
        ToTensorV2(),
    ]),
    val=A.Compose([
        A.SmallestMaxSize(max_size=160),
        A.CenterCrop(height=128, width=128),
        A.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
        ToTensorV2(),
    ]),
)

# Mini-batch size shared by the training and validation loaders.
batch_size = 32

# Training split: augmented samples, reshuffled every epoch.
train_dataset = CustomDataset(
    os.path.join(image_path, "train"),
    transform=data_transform["train"],
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,  # load in the main process
)

# Validation split: deterministic preprocessing, fixed sample order.
validate_dataset = CustomDataset(
    os.path.join(image_path, "val"),
    transform=data_transform["val"],
)
val_num = len(validate_dataset)  # number of validation samples
validate_loader = DataLoader(
    validate_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,  # load in the main process
)
