In [3]:
# 对 cinic-10 数据集进行随机抽样保存（train 抽 5%，test 抽 50%）
import os
import random
import shutil
def sample_cinic10_subset(source_dir, output_dir, train_ratio=0.05, test_ratio=0.5, seed=42):
    """Copy a random per-class subset of a split/class-folder image dataset.

    Expects ``source_dir/<split>/<class>/<files>`` for split in ('train', 'test')
    and mirrors that layout under ``output_dir``. From each class folder it
    copies ``int(n_files * ratio)`` randomly chosen files. A split whose source
    folder does not exist is skipped.

    Args:
        source_dir: root of the full dataset (e.g. CINIC-10).
        output_dir: root to write the subset into; created if missing.
        train_ratio: fraction of each train class to copy, in [0, 1].
        test_ratio: fraction of each test class to copy, in [0, 1].
        seed: seed for the sampling RNG. The same seed over the same source
            files always selects the same subset.

    Raises:
        ValueError: if ``train_ratio`` or ``test_ratio`` is outside [0, 1].
    """
    for ratio_name, ratio_value in (('train_ratio', train_ratio), ('test_ratio', test_ratio)):
        if not 0 <= ratio_value <= 1:
            raise ValueError(f"{ratio_name} must be in [0, 1], got {ratio_value!r}")
    # Private RNG: same shuffle sequence as random.seed(seed) + random.shuffle,
    # but it does not reset the global `random` state other cells may use.
    rng = random.Random(seed)
    os.makedirs(output_dir, exist_ok=True)
    for split, ratio in (('train', train_ratio), ('test', test_ratio)):
        split_src_dir = os.path.join(source_dir, split)
        if not os.path.isdir(split_src_dir):
            continue
        split_dst_dir = os.path.join(output_dir, split)
        os.makedirs(split_dst_dir, exist_ok=True)
        # os.listdir order is filesystem-dependent; sorting makes the seeded
        # sampling select the same files on every machine.
        classes = sorted(
            cls for cls in os.listdir(split_src_dir)
            if os.path.isdir(os.path.join(split_src_dir, cls))
        )
        for cls in classes:
            cls_src_dir = os.path.join(split_src_dir, cls)
            cls_dst_dir = os.path.join(split_dst_dir, cls)
            os.makedirs(cls_dst_dir, exist_ok=True)
            images = sorted(
                img for img in os.listdir(cls_src_dir)
                if os.path.isfile(os.path.join(cls_src_dir, img))
            )
            rng.shuffle(images)
            num_select = int(len(images) * ratio)
            for img_name in images[:num_select]:
                shutil.copyfile(
                    os.path.join(cls_src_dir, img_name),
                    os.path.join(cls_dst_dir, img_name),
                )

# Anchor dataset paths on the kernel's current working directory (os.getcwd())
# rather than on the notebook/script file location.
script_dir = os.getcwd()
sample_cinic10_subset(
    os.path.join(script_dir, 'datasets/cinic-10'),      # source_dir: full CINIC-10
    os.path.join(script_dir, 'datasets/cnn/cinic-10'),  # output_dir: sampled subset
)

In [8]:
# CIFAR-10 转换为 ImageFolder/CINIC 格式（按类分文件夹，按 train/test 分目录）
import os
from torchvision.datasets import CIFAR10
import shutil

def convert_cifar10_to_imagefolder(root_cifar, output_dir, use_class_names=False):
    """Export torchvision's CIFAR-10 to an on-disk ImageFolder layout.

    Writes ``output_dir/<split>/<class_dir>/<idx>.png`` for split in
    ('train', 'test'). ``idx`` is the sample's position within its split.
    The dataset must already be present under ``root_cifar``; it is loaded
    with ``download=False``.

    Args:
        root_cifar: torchvision CIFAR10 root directory.
        output_dir: destination root; sub-folders are created as needed.
        use_class_names: if False (default), class folders are the numeric
            labels "0", "1", ...; if True, they are the names in
            ``dataset.classes``, giving CINIC-10-style folder names.
    """
    dataset_train = CIFAR10(root=root_cifar, train=True, download=False)
    dataset_test = CIFAR10(root=root_cifar, train=False, download=False)

    for split, dataset in (('train', dataset_train), ('test', dataset_test)):
        if use_class_names:
            class_dirs = [os.path.join(output_dir, split, name) for name in dataset.classes]
        else:
            class_dirs = [
                os.path.join(output_dir, split, str(label))
                for label in range(len(dataset.classes))
            ]
        # Create each class folder once per split instead of once per image.
        for class_dir in class_dirs:
            os.makedirs(class_dir, exist_ok=True)
        for idx, (img, label) in enumerate(dataset):
            img.save(os.path.join(class_dirs[label], f"{idx}.png"))

# Build dataset paths relative to the current working directory.
script_dir = os.getcwd()
# Raw torchvision CIFAR-10 data (flat, not split into per-class folders, like MNIST).
root_cifar = os.path.join(script_dir, 'datasets/cifar10')
# Converted output: one folder per class, like CINIC-10.
output_dir = os.path.join(script_dir, 'datasets/cifar-10')

convert_cifar10_to_imagefolder(root_cifar=root_cifar, output_dir=output_dir)

In [9]:
# 对 cifar-10 数据集进行随机抽样保存（train 抽 50%，test 抽 50%）
import os
import random
import numpy as np
import pandas as pd
import socket
import pickle
from time import sleep
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import torch
from torchvision.datasets import ImageFolder
from torchvision import transforms
from PIL import Image
import shutil

def sample_cifar10_subset(source_dir, output_dir, train_ratio=0.5, test_ratio=0.5, seed=42):
    """Copy a random per-class subset of a split/class-folder image dataset.

    Expects ``source_dir/<split>/<class>/<files>`` for split in ('train', 'test')
    (e.g. the converted CIFAR-10 tree) and mirrors that layout under
    ``output_dir``. From each class folder it copies ``int(n_files * ratio)``
    randomly chosen files. A split whose source folder does not exist is
    skipped.

    Args:
        source_dir: root of the full dataset.
        output_dir: root to write the subset into; created if missing.
        train_ratio: fraction of each train class to copy, in [0, 1].
        test_ratio: fraction of each test class to copy, in [0, 1].
        seed: seed for the sampling RNG. The same seed over the same source
            files always selects the same subset.

    Raises:
        ValueError: if ``train_ratio`` or ``test_ratio`` is outside [0, 1].
    """
    for ratio_name, ratio_value in (('train_ratio', train_ratio), ('test_ratio', test_ratio)):
        if not 0 <= ratio_value <= 1:
            raise ValueError(f"{ratio_name} must be in [0, 1], got {ratio_value!r}")
    # Private RNG: same shuffle sequence as random.seed(seed) + random.shuffle,
    # but it does not reset the global `random` state other cells may use.
    rng = random.Random(seed)
    os.makedirs(output_dir, exist_ok=True)
    for split, ratio in (('train', train_ratio), ('test', test_ratio)):
        split_src_dir = os.path.join(source_dir, split)
        if not os.path.isdir(split_src_dir):
            continue
        split_dst_dir = os.path.join(output_dir, split)
        os.makedirs(split_dst_dir, exist_ok=True)
        # os.listdir order is filesystem-dependent; sorting makes the seeded
        # sampling select the same files on every machine.
        classes = sorted(
            cls for cls in os.listdir(split_src_dir)
            if os.path.isdir(os.path.join(split_src_dir, cls))
        )
        for cls in classes:
            cls_src_dir = os.path.join(split_src_dir, cls)
            cls_dst_dir = os.path.join(split_dst_dir, cls)
            os.makedirs(cls_dst_dir, exist_ok=True)
            images = sorted(
                img for img in os.listdir(cls_src_dir)
                if os.path.isfile(os.path.join(cls_src_dir, img))
            )
            rng.shuffle(images)
            num_select = int(len(images) * ratio)
            for img_name in images[:num_select]:
                shutil.copyfile(
                    os.path.join(cls_src_dir, img_name),
                    os.path.join(cls_dst_dir, img_name),
                )

# Build dataset paths relative to the current working directory.
script_dir = os.getcwd()
# Use the CIFAR sampler defined in this cell (defaults: 50% train, 50% test).
# sample_cinic10_subset would apply its 5% train default, and it only exists
# if the CINIC cell happened to run first.
sample_cifar10_subset(
    source_dir=os.path.join(script_dir, 'datasets/cifar-10'),
    output_dir=os.path.join(script_dir, 'datasets/cnn/cifar-10')
)