In [1]:
import torch
import os 
import glob
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [2]:
def is_grayscale(img):
    """Tell whether a PIL image is stored in PIL's 'L' (8-bit grayscale) mode.

    Only the exact mode string 'L' counts; other modes such as '1', 'LA' or 'I'
    are reported as not grayscale.
    """
    grayscale_mode = 'L'
    return grayscale_mode == img.mode

In [8]:
class CustomImageDataset(Dataset):
    """Image-classification dataset read from a ``<root>/*/<label>/*.jpg`` tree.

    The label of each image is the name of the folder that directly contains it,
    mapped to an integer through ``label_dict``. Files for which ``is_grayscale``
    is true are detected once at construction time, reported with a print, and
    left out of the dataset, so ``__getitem__`` never returns ``(None, None)``.

    Decoded RGB images are cached in memory by index; after the first access of
    an item only ``transform`` is re-applied. Note that the cache keeps every
    accessed image alive for the lifetime of the dataset.

    Args:
        image_paths: Root directory to search (despite the name, a single path).
        transform: Optional callable applied to the RGB PIL image on every access.

    Raises:
        KeyError: from ``__getitem__`` if an image's parent folder is not a key
            of ``label_dict``.
    """

    def __init__(self, image_paths, transform=None):
        # Only .jpg files exactly two folder levels below the root are picked up.
        # Sorted so the index -> file mapping is stable across runs and platforms.
        candidates = sorted(glob.glob(os.path.join(image_paths, '*', '*', '*.jpg')))
        self.transform = transform
        self.label_dict = {'dew': 0, 'fogsmog': 1, 'frost': 2, 'glaze': 3, 'hail': 4, 'lightning': 5,
                           'rain': 6, 'rainbow': 7, 'rime': 8, 'sandstorm': 9, 'snow': 10}
        # index -> (RGB PIL image, int label); filled lazily by __getitem__.
        self.cache = {}
        self.image_paths = [path for path in candidates if self._is_usable(path)]

    def _is_usable(self, image_path):
        """Return False (after reporting the path) when the file is grayscale."""
        # Check the mode on the file as stored, before any RGB conversion;
        # Image.open without load() does not decode the pixel data.
        with Image.open(image_path) as probe:
            grayscale = is_grayscale(probe)
        if grayscale:
            print('흑백이미지 >>>', image_path)  # "grayscale image >>> <path>"
        return not grayscale

    def _label_for(self, image_path):
        """Map an image path to its integer label via its parent folder name."""
        # os.path instead of splitting on '\\' so this works on any OS and any root depth.
        folder_name = os.path.basename(os.path.dirname(image_path))
        return self.label_dict[folder_name]

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``, decoding and caching on first access."""
        if index in self.cache:
            img, label = self.cache[index]
        else:
            image_path = self.image_paths[index]
            # The with-block closes the file once the RGB copy has been made.
            with Image.open(image_path) as raw:
                img = raw.convert('RGB')
            label = self._label_for(image_path)
            self.cache[index] = (img, label)

        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Number of non-grayscale .jpg files found at construction time."""
        return len(self.image_paths)

In [4]:
# Preprocessing applied to every RGB PIL image the dataset returns: resize to a
# fixed 224x224 (a (h, w) size does not preserve aspect ratio), then convert the
# image to a tensor with ToTensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

In [5]:
# Root of the image tree; files are expected at <root>/*/<label>/*.jpg.
image_paths = './data/sample_data_01/'
dataset = CustomImageDataset(image_paths, transform=transform)
# Fail early with a readable message instead of RandomSampler's num_samples=0 error.
assert len(dataset) > 0, f'no usable .jpg images found under {image_paths!r}'

BATCH_SIZE = 32
SEED = 42  # seeds the shuffle order so Restart & Run All gives the same batches
data_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)