In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt                            # 显示图片
from PIL import Image                                      # 读取图片

In [None]:
# Preprocessing used when splitting the local data into training and test sets.
# Normalize each RGB channel with mean 0.5 and std 0.5.
normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))


def _make_transformer():
    """Build the resize -> center-crop -> to-tensor -> normalize pipeline.

    Fully connected layers have fixed input/output sizes, so the images fed
    to the convolutional layers must have a fixed size as well; 32x32 is
    used here, as in CIFAR10.
    """
    steps = [
        transforms.Resize((32, 32)),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)


# The training and test pipelines are identical but kept as separate objects.
train_transformer = _make_transformer()

test_transformer = _make_transformer()


class MyDataset(Dataset):
    """Dataset that loads images from file paths and pairs them with labels.

    Args:
        filenames: Sequence of image file paths.
        labels: Sequence of labels, indexed in parallel with ``filenames``.
        transform: Callable applied to every loaded RGB image.
    """

    def __init__(self, filenames, labels, transform):
        self.filenames = filenames
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of images."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for the sample at ``idx``."""
        # Image.open is lazy and keeps the underlying file open. The context
        # manager closes it as soon as convert() has produced an in-memory
        # copy, instead of leaving the handle to the garbage collector.
        with Image.open(self.filenames[idx]) as source:
            image = source.convert('RGB')
        image = self.transform(image)
        return image, self.labels[idx]


def split_train_test_data(data_dir, ratio, batch_size=100):
    """Split an image-folder dataset into training and test loaders, per class.

    For every class, the first ``int(n * ratio[0])`` images (in the order they
    appear in ``ImageFolder(data_dir).samples``) go to the training set and
    the next ``int(n * ratio[1])`` images go to the test set, where ``n`` is
    the number of images in that class. Because the counts are truncated
    with ``int()``, a class may contribute fewer than ``n`` images in total.

    Args:
        data_dir: Directory handed to ``ImageFolder``; it should be the
            parent of the per-class sub-directories.
        ratio: Sequence whose first two entries are the training and test
            fractions; they are expected to sum to 1.
        batch_size: Number of images per batch for both loaders. Defaults to
            100, the value that used to be hard-coded.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader
        shuffles its samples; the test loader does not.
    """
    dataset = ImageFolder(data_dir)

    # Group the image paths by class index.
    paths_by_class = [[] for _ in dataset.classes]
    for path, label in dataset.samples:
        paths_by_class[label].append(path)

    train_inputs, test_inputs = [], []
    train_labels, test_labels = [], []
    for label, paths in enumerate(paths_by_class):   # paths: one class
        num_sample_train = int(len(paths) * ratio[0])
        num_sample_test = int(len(paths) * ratio[1])
        num_test_index = num_sample_train + num_sample_test

        train_paths = paths[:num_sample_train]
        test_paths = paths[num_sample_train:num_test_index]

        train_inputs.extend(str(p) for p in train_paths)
        train_labels.extend([label] * len(train_paths))
        test_inputs.extend(str(p) for p in test_paths)
        test_labels.extend([label] * len(test_paths))

    train_loader = DataLoader(
        MyDataset(train_inputs, train_labels, train_transformer),
        batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(
        MyDataset(test_inputs, test_labels, test_transformer),
        batch_size=batch_size, shuffle=False)

    return train_loader, test_loader


# Build the loaders: 80% of each class for training, 20% for testing.
train_loader, test_loader = split_train_test_data(
    data_dir='E:\\datasets\\Birds',
    ratio=[0.8, 0.2],
)