# Data

## Dataset

In [37]:
from torch.utils.data import Dataset, DataLoader
from skimage import io, transform
import pathlib

class ShapeDataset(Dataset):
    
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.__init()
    
    def __init(self):
        self.jpg_files = [f for f in pathlib.Path(self.root_dir).glob('**/*.jpg') 
                          if '.ipynb_checkpoints' not in f.parts]
        
    def __len__(self):
        return len(self.jpg_files)
    
    def __getitem__(self, idx):
        img_path = self.jpg_files[idx]
        
        image = io.imread(img_path)
        sample = {'image': image, 'class': img_path.parts[2]}
        return sample

In [32]:
dataset = ShapeDataset('./shapes/train/circle')

In [33]:
len(dataset)

10

In [34]:
dataset.jpg_files

[PosixPath('shapes/train/circle/0008.jpg'),
 PosixPath('shapes/train/circle/0007.jpg'),
 PosixPath('shapes/train/circle/0006.jpg'),
 PosixPath('shapes/train/circle/0005.jpg'),
 PosixPath('shapes/train/circle/0009.jpg'),
 PosixPath('shapes/train/circle/0001.jpg'),
 PosixPath('shapes/train/circle/0000.jpg'),
 PosixPath('shapes/train/circle/0002.jpg'),
 PosixPath('shapes/train/circle/0004.jpg'),
 PosixPath('shapes/train/circle/0003.jpg')]

In [36]:
for i in range(len(dataset)):
    sample = dataset[i]
    print(i, sample['image'].size, sample['image'].shape, sample['class'])

0 62208 (144, 144, 3) circle
1 62208 (144, 144, 3) circle
2 62208 (144, 144, 3) circle
3 62208 (144, 144, 3) circle
4 62208 (144, 144, 3) circle
5 62208 (144, 144, 3) circle
6 62208 (144, 144, 3) circle
7 62208 (144, 144, 3) circle
8 62208 (144, 144, 3) circle
9 62208 (144, 144, 3) circle


## Data loader

In [38]:
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)

In [39]:
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size)

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/root/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/root/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/root/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 74, in default_collate
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/root/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 74, in <dictcomp>
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/root/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 81, in default_collate
    raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'imageio.core.util.Array'>


In [15]:
from torchvision import datasets

image_folder = datasets.ImageFolder('./shapes/train', transform=None)

In [16]:
image_folder.classes

['circle', 'poly', 'rect']