In [1]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import glob
from tqdm import tqdm, tqdm_notebook
import os
import pandas as pd

from torchvision import transforms
from torchvision.transforms import Resize, ToTensor, Normalize

## Training Dataset 정의

In [2]:
class TrainDataset(Dataset):
    def __init__(self, data_root, transform):
        self.data_root = data_root
        self.transform = transform
        self.imag_list = self._load_img_list(data_root)

    def __getitem__(self, index):
        img_path = self.imag_list[index]
        img = Image.open(img_path)
        
        if self.transform:
            img = self.transform(img)

        # Ground Truth
        label = self._get_class_idx_from_img_name(img_path)

        return img, label

    def __len__(self):
        return len(self.imag_list)
    
    def _load_img_list(self, data_root):
        img_list = []
        image_dir = os.path.join(data_root, 'images')
        
        for dir in glob.glob(image_dir + '/*'):
            img_list.extend(glob.glob(dir+'/*'))

        return img_list

    def _load_img_ID(self, img_path):
        return img_path.split('/')[7].split('_')[0]

    def _get_class_idx_from_img_name(self, img_path):
        img_name = os.path.basename(img_path)
        img_id = self._load_img_ID(img_path)
        
        img_idx = train_data.loc[train_data['id'] == img_id].index
        v = train_data.iloc[img_idx]['age+gender'].values[0]
        if 'normal' in img_name:
            return 12 + v
        elif 'incorrect_mask' in img_name:
            return 6 + v
        else:
            return 0 + v

In [3]:
train_dir = '/opt/ml/input/data/train'
# meta 데이터와 이미지 경로를 불러옵니다.
train_data = pd.read_csv(os.path.join(train_dir, 'train_1.csv'))

# Test Dataset 클래스 객체를 생성하고 DataLoader를 만듭니다.
transform = transforms.Compose([
    Resize((512, 384), Image.BILINEAR),
    ToTensor(),
    Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
])


dataset = TrainDataset(train_dir, transform)

In [4]:
len(dataset)

18900

In [5]:
dataset[0]

(tensor([[[-0.5980, -0.6569, -0.6569,  ...,  1.4608,  1.4608,  1.4608],
          [-0.6569, -0.6765, -0.6961,  ...,  1.4608,  1.4608,  1.4608],
          [-0.6961, -0.7157, -0.7157,  ...,  1.4608,  1.4608,  1.4608],
          ...,
          [-2.0098, -1.8333, -1.4804,  ...,  1.2451,  1.2451,  1.2255],
          [-2.0098, -2.0490, -2.0098,  ...,  1.2451,  1.2255,  1.2255],
          [-2.1471, -2.1863, -2.1275,  ...,  1.2451,  1.2255,  1.2255]],
 
         [[ 0.6765,  0.6569,  0.6569,  ...,  1.5784,  1.5784,  1.5784],
          [ 0.6765,  0.6373,  0.6569,  ...,  1.5784,  1.5784,  1.5784],
          [ 0.6373,  0.5980,  0.6373,  ...,  1.5784,  1.5784,  1.5784],
          ...,
          [-2.0882, -2.1471, -2.1275,  ...,  1.3431,  1.3431,  1.3235],
          [-2.0294, -2.1275, -2.2059,  ...,  1.3431,  1.3235,  1.3235],
          [-2.2255, -2.0882, -2.0294,  ...,  1.3431,  1.3235,  1.3235]],
 
         [[ 0.7549,  0.7745,  0.7745,  ...,  1.5392,  1.5392,  1.5392],
          [ 0.7353,  0.7549,

In [6]:
dataset[0][0].size()

torch.Size([3, 512, 384])

##  Dataloader 정의

In [7]:
loader = DataLoader(
    dataset,
    batch_size = 32,
    shuffle=False
)

In [8]:
from collections import defaultdict

In [11]:
class_counter = defaultdict(int)

for iter, (img, label) in tqdm(enumerate(loader)):
    for value in label:
        class_counter[value.item()] += 1

591it [03:34,  2.75it/s]


In [12]:
class_counter

defaultdict(int,
            {4: 4085,
             16: 817,
             10: 817,
             1: 2050,
             13: 410,
             7: 410,
             3: 3660,
             15: 732,
             9: 732,
             0: 2745,
             12: 549,
             6: 549,
             5: 545,
             17: 109,
             11: 109,
             2: 415,
             14: 83,
             8: 83})

In [18]:
import pandas as pd

In [28]:
pd.DataFrame.from_dict(dict(class_counter), orient = 'index').sort_index().T

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17
0,2745,2050,415,3660,4085,545,549,410,83,732,817,109,549,410,83,732,817,109
