In [None]:
def make_datapath_list(phase="train"):
    """
    Collect the paths of the JPEG images that belong to one dataset split.

    Parameters
    ----------
    phase : str
        Sub-directory of the dataset root to search: "train" or "val".

    Returns
    -------
    path_list : list of str
        Paths matching ``./data/hymenoptera_data/<phase>/*/*.jpg``, sorted so
        the order does not depend on the file system's listing order.
    """
    rootpath = "./data/hymenoptera_data/"
    # Pass the components separately so osp.join actually assembles the path
    # (joining a single pre-concatenated string is a no-op).
    target_path = osp.join(rootpath, phase, "**", "*.jpg")
    print(target_path)

    # glob is called without recursive=True, so "**" behaves like "*" and
    # matches exactly one directory level below the split directory.
    return sorted(glob.glob(target_path))

# Gather the image paths of both splits (training first, then validation).
train_list = make_datapath_list("train")
val_list = make_datapath_list("val")

# Last expression of the notebook cell: shows the training paths as output.
train_list

In [None]:
# ありとはちの画像のDataset

class HymenopteraDataset(data.Dataset):
    """
    Dataset of ant and bee images; subclasses PyTorch's Dataset.

    Attributes
    ----------
    file_list : list
        List holding the image paths.
    transform : object
        Instance of the preprocessing class, called as ``transform(img, phase)``.
    phase : 'train' or 'val'
        Selects training or validation mode.
    """

    # Name of the directory that contains the image -> integer class label.
    _LABELS = {"ants": 0, "bees": 1}

    def __init__(self, file_list, transform, phase):
        self.file_list = file_list
        self.transform = transform
        self.phase = phase

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.file_list)

    def __getitem__(self, index):
        """
        Return the preprocessed image and its integer label.

        Parameters
        ----------
        index : int
            Position of the image path inside ``file_list``.

        Returns
        -------
        (img_transformed, label)
            Output of ``transform(img, phase)`` and 0 for ants / 1 for bees.

        Raises
        ------
        ValueError
            If the image's parent directory is not a known class name.
        """
        # Load the index-th image.
        img_path = self.file_list[index]
        img = Image.open(img_path)

        # Preprocessing; the original note gives torch.Size([3, 224, 224])
        # as the expected result shape (depends on the transform in use).
        img_transformed = self.transform(img, self.phase)

        return img_transformed, self._label_from_path(img_path)

    @classmethod
    def _label_from_path(cls, img_path):
        """Map an image path to its numeric label via its parent directory."""
        # Reading the directory name instead of slicing fixed character
        # offsets keeps this working for any dataset root and any phase.
        class_name = osp.basename(osp.dirname(img_path))
        try:
            return cls._LABELS[class_name]
        except KeyError:
            raise ValueError(
                f"cannot derive label from {img_path!r}: parent directory "
                f"{class_name!r} is not one of {sorted(cls._LABELS)}"
            ) from None
    
# execute: build one dataset per split, each with its own transform instance.
train_dataset = HymenopteraDataset(
    file_list=train_list, transform=ImageTransform(size, mean, std), phase="train"
)
val_dataset = HymenopteraDataset(
    file_list=val_list, transform=ImageTransform(size, mean, std), phase="val"
)

# test: fetch a single sample through normal indexing (not the dunder method)
# and only once, so the image is not loaded and transformed twice.
index = 0
sample_img, sample_label = train_dataset[index]
print(sample_img.size())
print(sample_label)

    
        