In [5]:
# 从原始图像文件开始读取
# 并将它们转换为张量格式
# https://www.kaggle.com/c/cifar-10
# https://www.zhihu.com/question/54883612/answer/130707137363

import csv
import os
import shutil
from collections import Counter

# Layout of the Kaggle CIFAR-10 competition data on disk
data_files_dir = r'../data/kaggle-cifar-10'
# Image folders for the training and test splits
train_images_dir, test_images_dir = (
    os.path.join(data_files_dir, sub_dir) for sub_dir in ('train', 'test')
)
# CSV files shipped with the competition: training labels and the submission template
train_labels_file_path, submission_example_file_path = (
    os.path.join(data_files_dir, csv_name)
    for csv_name in ('trainLabels.csv', 'sampleSubmission.csv')
)


def read_csv_labels(file_path) -> dict:
    """
    Read a two-column CSV file and return a mapping from file name to label.

    The first row is treated as a header and skipped. Blank rows (for example
    a trailing empty line) are ignored, and surrounding whitespace is stripped
    from both fields. Quoted fields are handled by the ``csv`` module.

    :param file_path: path of the CSV file; each data row is ``name,label``
    :return: dict ``{name: label}`` in file order
    :raises ValueError: if a non-blank data row does not have exactly two fields
    """
    # newline='' is what the csv module expects, so it can handle line endings itself
    with open(file_path, 'r', encoding='utf-8', newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row; also safe on an empty file
        # Drop rows whose fields are all blank (csv.reader yields [] for empty
        # lines); otherwise tuple unpacking below would raise on them.
        rows = (row for row in reader if any(field.strip() for field in row))
        return {name.strip(): label.strip() for name, label in rows}


train_labels = read_csv_labels(train_labels_file_path)
print('# 训练样本 :', len(train_labels))
print('# 类别数 :', len({*train_labels.values()}))
# Notebook display: preview the first 10 (file name -> label) pairs
{name: train_labels[name] for name in list(train_labels)[:10]}

# 训练样本 : 50000
# 类别数 : 10


{'1': 'frog',
 '2': 'truck',
 '3': 'truck',
 '4': 'deer',
 '5': 'automobile',
 '6': 'automobile',
 '7': 'bird',
 '8': 'horse',
 '9': 'ship',
 '10': 'cat'}

In [8]:
# Notebook display: number of samples per label
Counter(label for label in train_labels.values())

Counter({'frog': 5000,
         'truck': 5000,
         'deer': 5000,
         'automobile': 5000,
         'bird': 5000,
         'horse': 5000,
         'ship': 5000,
         'cat': 5000,
         'dog': 5000,
         'airplane': 5000})

In [None]:
def copy_files(file_path, target_dir):
    """
    Copy a single file into a folder, creating the folder (and missing parents) first.

    :param file_path: path of the file to copy
    :param target_dir: destination folder; the copy keeps the source file's base name
    """
    source, destination = file_path, target_dir
    # exist_ok=True: an already existing destination folder is not an error
    os.makedirs(destination, exist_ok=True)
    shutil.copy(src=source, dst=destination)


def reorganize_train_valid(whole_train_dir: str,
                           labels: dict,
                           valid_ratio: float = 0.2):
    """
    Split the training set into a training part and a validation part.

    For each label, the first ``int(len(files) * valid_ratio)`` file names
    belonging to that label (in ``labels`` iteration order) are copied into
    ``os.path.join(valid_dir, label)`` and removed from the per-label ``files`` list.

    NOTE(review): the per-label loop builds its paths from ``train_dir`` and
    ``valid_dir``. Neither is a parameter nor a local, so both must exist at
    module scope when this is called, otherwise a NameError is raised.
    ``whole_train_dir`` is only passed to ``os.makedirs``. TODO confirm
    whether ``train_dir`` was meant to be ``whole_train_dir``.

    :param whole_train_dir: training-set directory
    :param labels: label dict ``{file name: class}``
    :param valid_ratio: fraction of each class's files to use for validation
    :raises ValueError: from ``min`` if ``labels`` is empty
    """
    label_counts = Counter(labels.values()) # number of samples per class in the training data
    min_count = min(label_counts.values()) # sample count of the smallest class
    # Intended minimum number of validation samples per class (at least 1).
    # NOTE(review): not used by the loop below, which slices by len(files) * valid_ratio instead.
    valid_count_per_label = max(1, int(min_count * valid_ratio)) # per-class minimum validation size
    files_by_label = {}  # NOTE(review): never populated or read in this body

    os.makedirs(whole_train_dir, exist_ok=True)  # ensure whole_train_dir exists, creating it if missing
    for label in set(labels.values()):
        # NOTE(review): the per-label folder is created under train_dir, not under valid_dir
        label_dir = os.path.join(train_dir, label)
        os.makedirs(label_dir, exist_ok=True)  # ensure the label folder exists, creating it if missing
        # All file names carrying the current label (a full scan of `labels` per label)
        files = [name for name, l in labels.items() if l == label]
        # Take the leading share of this label's files as the validation set.
        # This is not random: the order follows the iteration order of `labels`.
        valid_files = files[:int(len(files) * valid_ratio)]
        # Copy the validation files into the validation folder of this label.
        # Assumes the keys of `labels` are the exact on-disk file names
        # (e.g. whether an extension such as '.png' is needed) — TODO confirm.
        for file in valid_files:
            copy_files(os.path.join(train_dir, file), os.path.join(valid_dir, label))
            files.remove(file)  # remove from `files` so it keeps only this label's non-validation names (O(len(files)) per call)