### PSPNet用のDatasetを作成

前処理クラス

In [1]:
from data_argmentation_for_PSPNet import (
    Compose,
    Scale,
    RandomRotation,
    RandomMirror,
    Resize,
    Normalize_Tensor
)

In [2]:
class DataTransform:
    """
    Preprocessing for an image and its annotation image.

    Behaves differently per phase: the 'train' pipeline applies data
    augmentation (scaling, rotation, mirroring) before resizing and
    normalising, while the 'val' pipeline only resizes and normalises.
    Images are resized to (input_size, input_size).

    Attributes
    ----------
    input_size : int
        Target size of the resize step.
    color_mean : (R, G, B)
        Per-channel mean handed to the normalisation step.
    color_std : (R, G, B)
        Per-channel standard deviation handed to the normalisation step.
    """

    def __init__(self, input_size, color_mean, color_std):
        # Training: augment first, then bring to the final size and normalise.
        train_pipeline = Compose([
            Scale(scale=[0.5, 1.5]),
            RandomRotation(angle=[-10, 10]),
            RandomMirror(),
            Resize(input_size),
            Normalize_Tensor(color_mean, color_std),
        ])

        # Validation: no augmentation. The steps are separate instances from
        # the training ones rather than shared objects.
        val_pipeline = Compose([
            Resize(input_size),
            Normalize_Tensor(color_mean, color_std),
        ])

        self.data_transform = {'train': train_pipeline, 'val': val_pipeline}

    def __call__(self, phase, img, anno_class_img):
        """
        Run the pipeline registered for ``phase`` on the image pair.

        Parameters
        ----------
        phase : str
            Key into ``self.data_transform`` ('train' or 'val').
        img : object
            Input image, passed through to the pipeline unchanged.
        anno_class_img : object
            Annotation image, passed through to the pipeline unchanged.

        Returns
        -------
        Whatever the selected pipeline returns for (img, anno_class_img).

        Raises
        ------
        KeyError
            If ``phase`` is not a key of ``self.data_transform``.
        """
        pipeline = self.data_transform[phase]
        return pipeline(img, anno_class_img)

### VOCDataset

In [3]:
from torch.utils.data import Dataset
from PIL import Image

In [19]:
class VOCDataset(Dataset):
    """
    Dataset of VOC2012 image / annotation pairs, built on PyTorch's Dataset.

    Attributes
    ----------
    img_list : list
        Paths to the input images.
    anno_list : list
        Paths to the annotation images, index-aligned with ``img_list``.
    phase : str
        'train' or 'val'; forwarded to ``transform`` on every access.
    transform : object
        Preprocessing callable, invoked as
        ``transform(phase, img, anno_class_img)``.
    """

    def __init__(self, img_list, anno_list, phase, transform):
        """
        Store the path lists, the phase and the preprocessing callable.

        Raises
        ------
        ValueError
            If ``img_list`` and ``anno_list`` differ in length.
        """
        # Samples are paired purely by index, so a length mismatch would
        # otherwise surface late as an IndexError or go unnoticed entirely.
        if len(img_list) != len(anno_list):
            raise ValueError(
                "img_list and anno_list must have the same length "
                f"(got {len(img_list)} and {len(anno_list)})"
            )

        self.img_list = img_list
        self.anno_list = anno_list
        self.phase = phase
        self.transform = transform

    def __len__(self):
        """Return the number of samples (length of ``img_list``)."""
        return len(self.img_list)

    def __getitem__(self, index):
        """Return the preprocessed (image, annotation) pair at ``index``."""
        img, anno_class = self.pull_item(index)
        return img, anno_class

    def pull_item(self, index):
        """
        Load the image and annotation at ``index`` and preprocess them.

        Both files are opened with PIL and handed to ``self.transform``
        together with ``self.phase``.

        Returns
        -------
        tuple
            The (image, annotation) pair produced by ``self.transform``.
        """
        image_file_path = self.img_list[index]
        # Original note: the image is expected to be RGB.
        img = Image.open(image_file_path)

        anno_file_path = self.anno_list[index]
        # Original note: indexed as [h][w] -> colour-palette number.
        # Presumably a palette-mode image -- TODO confirm against the data.
        anno_class_img = Image.open(anno_file_path)

        img, anno_class_img = self.transform(self.phase, img, anno_class_img)

        return img, anno_class_img

In [20]:
from datapath_for_PSPNet import make_datapath_list

In [22]:
# Smoke test: build the train / val datasets and inspect one validation sample.

# Per-channel (R, G, B) mean and standard deviation passed to DataTransform.
color_mean = (0.485, 0.456, 0.406)
color_std = (0.229, 0.224, 0.225)

rootpath = "../data/VOCdevkit/VOC2012/"

train_img_list, train_anno_list, val_img_list, val_anno_list = make_datapath_list(
    rootpath=rootpath
)

# Datasets
train_transform = DataTransform(
    input_size=475, color_mean=color_mean, color_std=color_std
)
train_dataset = VOCDataset(
    train_img_list, train_anno_list, phase='train', transform=train_transform
)

val_transform = DataTransform(
    input_size=475, color_mean=color_mean, color_std=color_std
)
val_dataset = VOCDataset(
    val_img_list, val_anno_list, phase='val', transform=val_transform
)

# Fetch the sample once through the indexing protocol: every access re-opens
# and re-transforms both files, so the result is reused for all three prints.
sample = val_dataset[0]
print(sample[0].shape)
print(sample[1].shape)
print(sample)


torch.Size([3, 475, 475])
torch.Size([475, 475])
(tensor([[[ 1.6667,  1.5125,  1.5639,  ...,  1.7523,  1.6667,  1.7009],
         [ 1.5810,  1.4269,  1.4783,  ...,  1.7009,  1.6153,  1.6495],
         [ 1.5639,  1.4098,  1.4440,  ...,  1.6838,  1.5982,  1.6324],
         ...,
         [-0.4739, -0.4911, -0.5424,  ...,  1.2557,  1.1872,  1.2214],
         [-0.5596, -0.4911, -0.4911,  ...,  1.2385,  1.1872,  1.2214],
         [-0.6281, -0.3883, -0.3369,  ...,  1.2385,  1.1872,  1.2214]],

        [[ 1.8333,  1.6758,  1.7283,  ...,  1.9209,  1.8333,  1.8683],
         [ 1.7458,  1.5882,  1.6408,  ...,  1.8683,  1.7808,  1.8158],
         [ 1.7283,  1.5707,  1.6057,  ...,  1.8508,  1.7633,  1.7983],
         ...,
         [-0.5826, -0.6001, -0.6527,  ...,  1.4132,  1.3431,  1.3431],
         [-0.6702, -0.6001, -0.6001,  ...,  1.3957,  1.3431,  1.3431],
         [-0.7402, -0.4951, -0.4426,  ...,  1.3957,  1.3431,  1.3431]],

        [[ 2.0474,  1.8905,  1.9428,  ...,  2.1346,  2.0474,  2.08 ... [output truncated]

### Dataloaderの作成

In [23]:
from torch.utils.data import DataLoader

In [24]:
# Mini-batch size shared by both loaders.
batch_size = 8

# Only the training loader shuffles; the validation loader does not.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=False,
)

# Look up a loader by phase name.
dataloader_dict = {
    'train': train_dataloader,
    'val': val_dataloader,
}

# Sanity check: pull the first validation batch and print its shapes.
batch_iterator = iter(dataloader_dict['val'])
images, anno_class_images = next(batch_iterator)
print(images.shape)
print(anno_class_images.shape)


torch.Size([8, 3, 475, 475])
torch.Size([8, 475, 475])
