# Data

## Dataset

In [1]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision.transforms import *
from PIL import Image 
import pathlib

class ShapeDataset(Dataset):
    
    def __init__(self, root_dir, transform=transforms.Compose([Resize(256), RandomCrop(224), ToTensor()])):
        self.root_dir = root_dir
        self.transform = transform
        self.__init()
    
    def __init(self):
        self.jpg_files = [f for f in pathlib.Path(self.root_dir).glob('**/*.jpg') 
                          if '.ipynb_checkpoints' not in f.parts]
        
    def __len__(self):
        return len(self.jpg_files)
    
    def __getitem__(self, idx):
        img_path = self.jpg_files[idx]
        
        image = Image.open(img_path)
        
        if self.transform:
            image = self.transform(image)
        
        sample = {'image': image, 'class': img_path.parts[2]}
        
        return sample

In [8]:
dataset = ShapeDataset('./shapes/train')

In [9]:
len(dataset)

30

In [10]:
dataset.jpg_files

[PosixPath('shapes/train/rect/0026.jpg'),
 PosixPath('shapes/train/rect/0029.jpg'),
 PosixPath('shapes/train/rect/0024.jpg'),
 PosixPath('shapes/train/rect/0023.jpg'),
 PosixPath('shapes/train/rect/0021.jpg'),
 PosixPath('shapes/train/rect/0028.jpg'),
 PosixPath('shapes/train/rect/0025.jpg'),
 PosixPath('shapes/train/rect/0022.jpg'),
 PosixPath('shapes/train/rect/0027.jpg'),
 PosixPath('shapes/train/rect/0020.jpg'),
 PosixPath('shapes/train/poly/0012.jpg'),
 PosixPath('shapes/train/poly/0016.jpg'),
 PosixPath('shapes/train/poly/0011.jpg'),
 PosixPath('shapes/train/poly/0018.jpg'),
 PosixPath('shapes/train/poly/0015.jpg'),
 PosixPath('shapes/train/poly/0017.jpg'),
 PosixPath('shapes/train/poly/0010.jpg'),
 PosixPath('shapes/train/poly/0014.jpg'),
 PosixPath('shapes/train/poly/0013.jpg'),
 PosixPath('shapes/train/poly/0019.jpg'),
 PosixPath('shapes/train/circle/0008.jpg'),
 PosixPath('shapes/train/circle/0007.jpg'),
 PosixPath('shapes/train/circle/0006.jpg'),
 PosixPath('shapes/train/cir

In [11]:
for i in range(len(dataset)):
    sample = dataset[i]
    print(i, sample['image'].shape, sample['class'])

0 torch.Size([3, 224, 224]) rect
1 torch.Size([3, 224, 224]) rect
2 torch.Size([3, 224, 224]) rect
3 torch.Size([3, 224, 224]) rect
4 torch.Size([3, 224, 224]) rect
5 torch.Size([3, 224, 224]) rect
6 torch.Size([3, 224, 224]) rect
7 torch.Size([3, 224, 224]) rect
8 torch.Size([3, 224, 224]) rect
9 torch.Size([3, 224, 224]) rect
10 torch.Size([3, 224, 224]) poly
11 torch.Size([3, 224, 224]) poly
12 torch.Size([3, 224, 224]) poly
13 torch.Size([3, 224, 224]) poly
14 torch.Size([3, 224, 224]) poly
15 torch.Size([3, 224, 224]) poly
16 torch.Size([3, 224, 224]) poly
17 torch.Size([3, 224, 224]) poly
18 torch.Size([3, 224, 224]) poly
19 torch.Size([3, 224, 224]) poly
20 torch.Size([3, 224, 224]) circle
21 torch.Size([3, 224, 224]) circle
22 torch.Size([3, 224, 224]) circle
23 torch.Size([3, 224, 224]) circle
24 torch.Size([3, 224, 224]) circle
25 torch.Size([3, 224, 224]) circle
26 torch.Size([3, 224, 224]) circle
27 torch.Size([3, 224, 224]) circle
28 torch.Size([3, 224, 224]) circle
29 tor

## ImageFolder

In [14]:
from torchvision import datasets

transform = transforms.Compose([Resize(256), RandomCrop(224), ToTensor()])
image_folder = datasets.ImageFolder('./shapes/train', transform=transform)

In [15]:
image_folder.classes

['circle', 'poly', 'rect']

In [25]:
for clazz in image_folder.classes:
    print(clazz, image_folder.class_to_idx[clazz])

circle 0
poly 1
rect 2


## DataLoader

In [12]:
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)

In [13]:
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['image'].size())

0 torch.Size([2, 3, 224, 224])
1 torch.Size([2, 3, 224, 224])
2 torch.Size([2, 3, 224, 224])
3 torch.Size([2, 3, 224, 224])
4 torch.Size([2, 3, 224, 224])
5 torch.Size([2, 3, 224, 224])
6 torch.Size([2, 3, 224, 224])
7 torch.Size([2, 3, 224, 224])
8 torch.Size([2, 3, 224, 224])
9 torch.Size([2, 3, 224, 224])
10 torch.Size([2, 3, 224, 224])
11 torch.Size([2, 3, 224, 224])
12 torch.Size([2, 3, 224, 224])
13 torch.Size([2, 3, 224, 224])
14 torch.Size([2, 3, 224, 224])


In [16]:
dataloader = DataLoader(image_folder, batch_size=2, shuffle=True, num_workers=4)

In [18]:
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched.size())

AttributeError: 'list' object has no attribute 'size'

In [21]:
dir(dataloader)

['_DataLoader__initialized',
 '_DataLoader__multiprocessing_context',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__iter__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_auto_collation',
 '_dataset_kind',
 '_index_sampler',
 'batch_sampler',
 'batch_size',
 'collate_fn',
 'dataset',
 'drop_last',
 'multiprocessing_context',
 'num_workers',
 'pin_memory',
 'sampler',
 'timeout',
 'worker_init_fn']