### Custom Dataset Class (2)

Create by Xi in 2022/3/30.</br>
Last update by Xi in 2022/4/8.

When all images are stored together in a single folder, and filename indicates the class.</br>
Structure like the following:
```
Root
└── Images
    ├── class01-image01.jpg
    ├── class01-image02.jpg
    ├── class02-image01.jpg
    ├── class02-image02.jpg
    └── ...
```

In [1]:
import glob
import numpy
from pathlib import Path

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision

In [2]:
# Change this for your own directory.
ROOT_DIR = Path("/root/jupyter_projects")  # Your working directory.
DATASET_DIR = ROOT_DIR / "Data"  # Folder to store dataset.

data_path = DATASET_DIR / "Caltech101"

In [3]:
path_list = []  # Store the path of images.
class_list = []  # Store the class name.

In [4]:
from pathlib import Path

pathlist = Path(data_path).glob('*')

for path in pathlist:
    # Because path is object not string.
    path_str = str(path.name)  # Class name.
    class_list.append(path_str)
    
    sub_path_list = Path(path).glob('*')
    
    for sub_path in sub_path_list:
        sub_path_str = str(sub_path)
        path_list.append(sub_path_str)

In [5]:
print(len(path_list))  # Total number of images.
print(len(class_list))  # Total number of class.

9145
102


In [6]:
idx_to_class = {i:j for i, j in enumerate(class_list)}  # Add index to class.
class_to_idx = {value:key for key,value in idx_to_class.items()}  # Convert class name to index.

In [7]:
# Have a look at idx_to_class.
for i in range(5):
    print(i, idx_to_class[i])

0 Faces_easy
1 crocodile
2 pyramid
3 stegosaurus
4 bonsai


In [8]:
# Test if is work correct.
print(class_to_idx["bonsai"])

4


In [9]:
class Caltech101Dataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform
        
    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        
        image_filepath = self.image_paths[idx]
        
        # For numpy array:
        # image = cv2.imread(image_filepath)
        # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # For torch tensor:
        image = torchvision.io.read_image(image_filepath)
        
        label = image_filepath.split('/')[-2]  # Get class name from the path.
        label = class_to_idx[label]  # Convert class name to index.
        
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        
        return image, label

In [10]:
Caltech101 = Caltech101Dataset(path_list)

In [11]:
for X, y in Caltech101:
    print(type(X))
    print(X.shape)  # (channel, width, height)
    
    print(type(y))
    print(y)
    break

<class 'torch.Tensor'>
torch.Size([3, 313, 256])
<class 'int'>
0
