### PSPNet用のDatasetを作成

前処理クラス

In [1]:
from data_argmentation_for_PSPNet import (
    Compose,
    Scale,
    RandomRotation,
    RandomMirror,
    Resize,
    Normalize_Tensor
)

In [2]:
class DataTransform():
    """
    画像とアノテーションの前処理クラス.訓練時と検証時で異なる動作をする.
    画像サイズ(input_size, input_size)にする.
    訓練時は水増しを行う.
    
    Attributes
    ----------
    input_size : int リサイズ先の大きさ.
    color_mean : (R,G,B) チャネルの平均値
    color_std  : (R,G,B) チャネルの標準偏差
    """
    def __init__(self, input_size, color_mean, color_std):
        self.data_transform = {
            'train': Compose([
                Scale(scale=[0.5,1.5]),
                RandomRotation(angle=[-10,10]),
                RandomMirror(),
                Resize(input_size),
                Normalize_Tensor(color_mean, color_std)
            ]),
            'val': Compose([
                Resize(input_size),
                Normalize_Tensor(color_mean, color_std)
            ])
        }
        
    def __call__(self, phase, img, anno_class_img):
        return self.data_transform[phase](img, anno_class_img)

### VOCDataset

In [3]:
from torch.utils.data import Dataset
from PIL import Image

In [4]:
class VOCDataset(Dataset):
    """
    VOC2012のDatasetを作成するクラス.PytorchのDatasetクラスを継承
    
    Attributes
    ----------
    img_list : list 画像のパスを格納したリスト
    anno_list : list アノテーションへのパスを格納したリスト
    phase : 'train' or 'val'
    transform : object 前処理クラス
    """
    
    def __init__(self, img_list, anno_list, phase, transform):
        self.img_list = img_list
        self.anno_list = anno_list
        self.phase = phase
        self.transform = transform
        
    def __len__(self):
        return len(self.img_list)
    
    def __getitem__(self, index):
        img, anno_class = self.pull_item(index)
        
    def pull_item(self, index):
        """
        画像のTensor形式データ、アノテーションを取得
        """
        image_file_path = self.img_list[index]
        img = Image.open(image_file_path) # (RGB)
        
        anno_file_path = self.anno_list[index]
        anno_class_img = Image.open(anno_file_path) #[h][w][カラーパレット番号]
        
        img, anno_class_img = self.transform(self.phase, img, anno_class_img)
        
        return img, anno_class_img

In [5]:
from datapath_for_PSPNet import make_datapath_list

In [8]:
# 動作確認

color_mean = (0.485, 0.456, 0.406)
color_std = (0.229, 0.224, 0.225)

rootpath = "../data/VOCdevkit/VOC2012/"

train_img_list, train_anno_list, val_img_list, val_anno_list \
= make_datapath_list(rootpath=rootpath)

print(f"train_img_list: {train_img_list}")

# dataset
train_transform = DataTransform(input_size=475, 
                                color_mean=color_mean, 
                                color_std=color_std)
train_dataset = VOCDataset(train_img_list, train_anno_list, 
                           phase='train', transform=train_transform)


val_transform = DataTransform(input_size=475,
                              color_mean=color_mean, 
                              color_std=color_std)
val_dataset = VOCDataset(val_img_list, val_anno_list, 
                         phase='val', transform=val_transform)

# データの取り出し
# print(val_dataset.__getitem__(0)[0].shape)
# print(val_dataset.__getitem__(0)[1].shape)
print(val_dataset.__getitem__(0))


None
