## Dataset

**training data를 loading 하기 위한 class**

#### dataset에서 하는 일
- training data가 저장된 directory에서 `data 읽어 들이기`
- train 하기 위한 data로 만들어주기
- `data augmentation`

#### dataset class function
+ __input__:  전체 input feature tensor와 target tensor
<br>
<br>

+ __init__(self): 
    + 필요한 변수 선언. 전체 x_data와 y_data load하거나 파일목록을 load
    + training data가 저장되어 있는 disk location으로부터 training 할 image file 이름을 불러오기
    + data augmentation을 위해 필요한 값들 설정  
<br>

+ __get_item__(self, index): 
    + idx에 해당하는 training data(torch.tensor)를 return
    + 모든 return 값들의 tensor shape이 같아야 함. collate_fn()때문
    + 학습 전 preprocessing을 진행
    + data augmentation 진행  
<br>

+ __len__(self): 
    + len은 training 하기 위한 전체 dataset의 길이를 return
    + 여기서 정의된 len에 따라 한 epoch동안 training할 dataset의 길이가 결정 -> mini batch가 만들어짐

In [2]:
import torch
import torch.utils.data as data
import torch.distributed as dist
import pandas as pd
import numpy as np
import os
import torchvision.transforms.functional as TF
from skimage import io
from torchvision import transforms, utils
import random

In [3]:
#기본 Dataset class
class BasicDataset(data.Dataset):
    def __init__(self, x_tensor, y_tensor):
        super(BasicDataset, self).__init__()

        self.x = x_tensor
        self.y = y_tensor
        
    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return len(self.x)

In [63]:
class KITTIDataset(data.Dataset):
    def __init__(self, path, mode='train'):
        '''
        mode = 'train' or 'eval'
        '''

        super(KITTIDataset, self).__init__()

        self.path = path
        self.mode = mode

        if self.mode == 'train':
            self.dir_name = 'training'
        else:
            self.dir_name = 'testing'

        self.image_path = os.path.join(self.path,self.dir_name , 'image_2')
        self.seg_path = os.path.join(self.path,self.dir_name , 'semantic_rgb')
        self.instance_path = os.path.join(self.path,self.dir_name , 'instance')

        self.image_files = os.listdir(self.image_path) 
        # 그 외 학습에 필요한 데이터 정보들...

    def __len__(self):
        return len(self.image_files)
    
    def augment(self, image,masks):
        """
        Applying the same augmentation to
        image and its corresponding mask
        Args:
            image(PIL Image): resized image(new_width,new_height)
            masks(PIL Image): resized mask(new_width,new_height)
        """

        # Random horizontal flipping
        if random.random() > 0.5:
            image = TF.hflip(image)
            masks = TF.hflip(masks)

        # Random vertical flipping
        if random.random() > 0.5:
            image = TF.vflip(image)
            masks = TF.vflip(masks)
        return image,masks

    def transform(self,image,masks,aug):
        """Applying a set of  transformations as a datapreprocessing task"""
        # convert to PIL Image.
        PIL_convert = transforms.ToPILImage()
        image = PIL_convert(image)
        masks = PIL_convert(masks.astype(np.int32))
        # resize the image and masks
        resize = transforms.Resize(size=(512,512))
        image = resize(image)
        masks = resize(masks)
        # augmentation
        if aug is True:
            self.augment(image,masks)
        else:
            pass
        # Convert to Tensor
        image = TF.to_tensor(image)
        masks = TF.to_tensor(masks)

        return image,masks

    def __getitem__(self,image_id):
        """
        Function to read the image and mask
        and return a sample of dataset when neededself.
        Args:
        image_id: image index to iterate over the dataset samples
        Returns:
        sample(dict): a sample of the dataset
        """
        # read the image
        image_path  = (os.path.join(self.image_path,self.image_files[image_id]))
        image = io.imread(image_path)
        # read the mask
        # mask_dir = os.path.join(self.seg_path,self.image_files[image_id])
        mask_dir = os.path.join(self.seg_path, self.image_files[image_id])
        masks_list = []
        for i, f in enumerate (next(os.walk(mask_dir))[2]):
            if f.endswith ('.png'):
                m = io.imread(os.path.join(mask_dir,f)).astype(np.bool)
                m = m[:,:,0]
                masks_list.append(m)
                #combine all the masks corresponding of an invidual sample image into single binary mask
                if len(masks_list) != 1:
                    masks = np.logical_or(masks,masks_list[i])
                else:
                    masks = masks_list[i]
        # do the transforms..
        trans_img,trans_masks = self.transform(image,masks,self.aug)
        sample = {"image":trans_img,"masks":trans_masks}

        return(sample)

In [64]:
if __name__ == "__main__":
    dataset = KITTIDataset("../data_semantics")
    print(len(dataset))
    print(dataset[0])

200


StopIteration: 