# 使用skimage的接口读取图片数据
- 构建dataset及dataloader

In [1]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from skimage import io


- 定义 Dataset
- 通过 `skimage.io.imread()` 函数读取出来的类型是 `<class 'numpy.ndarray'>`；对于 RGB 图像，其 shape 为 (height, width, chan_num)，即（行数，列数，通道数）。

In [16]:
class CatsAndDogsDataset(Dataset):
    """Map-style dataset of cat/dog images listed in a CSV annotation file.

    Each CSV row describes one sample: column 0 holds the image file name
    (relative to ``root_dir``) and column 1 holds the integer class label.

    Args:
        csv_file: Path (or file-like object) accepted by ``pandas.read_csv``.
        root_dir: Directory that contains the image files.
        transform: Optional callable applied to the loaded image. Without it
            the image is returned exactly as the loader produced it; pass
            e.g. ``transforms.ToTensor()`` to get a (channels, height, width)
            tensor that the default DataLoader collate can batch.
        loader: Optional callable ``loader(path) -> image``. When omitted,
            ``skimage.io.imread`` is used, which returns a ``numpy.ndarray``
            of shape (height, width, channels) for RGB images.
    """

    def __init__(self, csv_file, root_dir, transform=None, loader=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.loader = loader

    def __len__(self):
        # One sample per CSV row.
        return len(self.annotations)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample in CSV row ``index``.

        ``label`` is a 0-dim int64 tensor built from column 1 of the CSV.
        """
        file_name = self.annotations.iloc[index, 0]
        img_path = os.path.join(self.root_dir, file_name)

        # Fall back to skimage.io.imread when no custom loader was supplied.
        read = self.loader if self.loader is not None else io.imread
        image = read(img_path)

        # int() converts the pandas/numpy scalar to a plain Python int, so
        # torch.tensor() produces an int64 scalar tensor.
        y_label = torch.tensor(int(self.annotations.iloc[index, 1]))

        # Compare against None explicitly: a transform object that happens to
        # be falsy (e.g. defines __len__ returning 0) must still be applied.
        if self.transform is not None:
            image = self.transform(image)

        return image, y_label

- 创建Dataset及dataloader

In [17]:
import torch
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader

# Build the dataset, split it into train/test subsets, and wrap both in
# DataLoaders. ToTensor turns each (height, width, channels) image array into
# a (channels, height, width) tensor so batches can be stacked.
dataset = CatsAndDogsDataset(csv_file='cats_dogs.csv', root_dir='cats_dogs_resized',
                             transform=transforms.ToTensor())
# print(len(dataset))  # 10 images in this example

# 80/20 train/test split. The sizes are derived from len(dataset) rather than
# hard-coded as [8, 2], because random_split raises when the lengths do not sum
# to the dataset size. A seeded generator makes the split reproducible.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_set, test_set = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=torch.Generator().manual_seed(42)
)

batch_size = 2
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
# Evaluation does not depend on sample order, so the test loader is not shuffled.
test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

# Peek at a single training batch.
image, label = next(iter(train_loader))
print(len(image))   # number of samples in this batch
print(image.shape)  # e.g. torch.Size([2, 3, 224, 224]): (batch, channels, height, width)
print(label)

# Iterate over every training batch.
for batch_idx, (data, targets) in enumerate(train_loader):
    print(batch_idx)
    print(data.shape)
    print(targets)

(224, 224, 3)
(224, 224, 3)
2
torch.Size([2, 3, 224, 224])
tensor([1, 0])
(224, 224, 3)
(224, 224, 3)
0
torch.Size([2, 3, 224, 224])
tensor([0, 0])
(224, 224, 3)
(224, 224, 3)
1
torch.Size([2, 3, 224, 224])
tensor([0, 1])
(224, 224, 3)
(224, 224, 3)
2
torch.Size([2, 3, 224, 224])
tensor([0, 0])
(224, 224, 3)
(224, 224, 3)
3
torch.Size([2, 3, 224, 224])
tensor([0, 1])
