In [1]:
import glob
import os.path as osp
import random
import numpy as np
from PIL import Image

import torch
from torchvision import transforms
import torch.utils.data as data

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

def make_data_path(root):
    """
    Build the image path list and the id/label rows for a dataset directory.

    `root` must contain `id_label.npy` (a 2-D array whose first column is the image
    id) and an `images` sub-directory holding one `<id>.jpg` file per row.

    :param root: dataset directory; a trailing path separator is optional.
    :return:
    path_list: list of str -- `<root>/images/<id>.jpg`, one per row of id_label.npy
    labels: list -- the rows of id_label.npy, in the same order as path_list
    """
    # allow_pickle=True is required for object arrays (presumably ids mixed with
    # label strings -- confirm). Unpickling can execute arbitrary code, so only load
    # id_label.npy from a trusted source.
    ids_labels = np.load(osp.join(root, 'id_label.npy'), allow_pickle=True)
    image_dir = osp.join(root, 'images')

    # osp.join instead of string concatenation with a hard-coded '\\' separator, so
    # the paths are valid on every OS and `root` needs no trailing separator.
    path_list = [osp.join(image_dir, f'{img_id}.jpg') for img_id in ids_labels[:, 0]]

    return path_list, list(ids_labels)


In [3]:
class ImageTransform:
    """
    Callable image pre-processor with one torchvision pipeline per dataset phase.

    Both pipelines end by converting the image to a tensor and normalizing each
    channel with `mean` / `std`. They differ in the geometric step:

    * 'train' -- random resized crop (scale 0.5-1.0) plus random horizontal flip,
      i.e. data augmentation.
    * 'val'   -- deterministic resize followed by a center crop.

    Attributes
    ----------
    data_transform : dict
        Maps 'train' / 'val' to the corresponding transforms.Compose pipeline.
    """

    def __init__(self, resize, mean, std):
        """
        :param resize: output crop size passed to the crop/resize transforms
        :param mean: per-channel normalization mean (R, G, B)
        :param std: per-channel normalization std (R, G, B)
        """
        def _to_normalized_tensor():
            # A fresh ToTensor + Normalize tail for each pipeline.
            return [transforms.ToTensor(), transforms.Normalize(mean, std)]

        train_steps = [
            transforms.RandomResizedCrop(resize, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(),
        ] + _to_normalized_tensor()

        val_steps = [
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
        ] + _to_normalized_tensor()

        self.data_transform = dict(
            train=transforms.Compose(train_steps),
            val=transforms.Compose(val_steps),
        )

    def __call__(self, img, phase='train'):
        """
        Run the pipeline registered for `phase` on `img`.

        :param img: input image (a PIL image when used by ProductDataset)
        :param phase: 'train' or 'val'; any other key raises KeyError
        :return: whatever the selected pipeline produces (a normalized tensor)
        """
        pipeline = self.data_transform[phase]
        return pipeline(img)

In [4]:
class ProductDataset(data.Dataset):
    """
    Product image dataset; subclass of torch.utils.data.Dataset.

    Attributes
    ----------
    file_list : list
        -> image file paths (first element of the `file_list` pair given to __init__)
    labels : list
        -> per-image rows, each unpacked as (img_id, label); same order as file_list
    transform : callable or None
        -> pre-processing callable, invoked as transform(img, phase); when None the
           RGB PIL image is returned unchanged
    phase : 'train' or 'val'
        -> forwarded to `transform` to select the pipeline
    """

    def __init__(self, file_list, transform=None, phase="train"):
        """
        :param file_list: (paths, labels) pair, e.g. the tuple returned by make_data_path
        :param transform: callable(img, phase) -> processed sample, or None
        :param phase: 'train' or 'val'
        :raises ValueError: if paths and labels have different lengths
        """
        self.file_list, self.labels = file_list
        if len(self.file_list) != len(self.labels):
            raise ValueError(
                f"file_list has {len(self.file_list)} paths but {len(self.labels)} label rows"
            )
        self.transform = transform
        self.phase = phase

    def __len__(self):
        """Return the number of images."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """
        Load, convert and pre-process one image.

        :param idx: index of the sample
        :return: (processed image, label)
        """
        img_path = self.file_list[idx]
        _img_id, label = self.labels[idx]

        # The context manager releases the file handle as soon as the pixels are read
        # (Image.open is lazy and would otherwise keep the file open until GC).
        # convert("RGB") guarantees 3 channels; grayscale/RGBA/palette images would
        # otherwise not match a 3-element Normalize mean/std.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        if self.transform is None:
            return img, label

        # e.g. torch.Size([3, 224, 224]) with ImageTransform(224, ...)
        img_transformed = self.transform(img, self.phase)
        return img_transformed, label


In [6]:
import os
from dotenv import load_dotenv

# Configuration
SEED = 42           # fixes the train/val split and the DataLoader shuffle order
VAL_FRACTION = 0.2  # fraction of the listing held out for validation
size = 224
mean = (0.485, 0.456, 0.406)  # ImageNet normalization statistics (R, G, B)
std = (0.229, 0.224, 0.225)
batch_size = 32

load_dotenv()
root = os.environ.get("ROOT")
if not root:
    # Fail with a clear message instead of an obscure TypeError inside make_data_path.
    raise RuntimeError("ROOT is not set; define it in the environment or in a .env file")

torch.manual_seed(SEED)

# Build ONE listing and split it into disjoint train/val subsets. Previously both
# train_list and val_list were the full listing, so validation scored the model on
# the very images it was trained on.
all_paths, all_labels = make_data_path(root)
order = list(range(len(all_paths)))
random.Random(SEED).shuffle(order)
n_val = int(len(order) * VAL_FRACTION)
val_idx, train_idx = order[:n_val], order[n_val:]
train_list = ([all_paths[i] for i in train_idx], [all_labels[i] for i in train_idx])
val_list = ([all_paths[i] for i in val_idx], [all_labels[i] for i in val_idx])

image_transform = ImageTransform(size, mean, std)
train_dataset = ProductDataset(file_list=train_list, transform=image_transform, phase='train')
# phase='val' explicitly: ProductDataset defaults to 'train', which would apply the
# random crop/flip augmentation to validation images.
val_dataset = ProductDataset(file_list=val_list, transform=image_transform, phase='val')

train_dataLoader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataLoader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
dataLoaders_dict = {"train": train_dataLoader, "val": val_dataLoader}

# Sanity check: wrap the train loader in an iterator and fetch its first batch.
batch_iterator = iter(dataLoaders_dict["train"])
inputs, labels = next(batch_iterator)
print(inputs.size())
print(labels)

torch.Size([32, 3, 224, 224])
('Tshirts', 'Handbags', 'Wallets', 'Wallets', 'Clutches', 'Handbags', 'Tshirts', 'Sports Shoes', 'Kurta Sets', 'Sports Shoes', 'Watches', 'Nightdress', 'Kurtas', 'Shirts', 'Tshirts', 'Shirts', 'Tshirts', 'Tshirts', 'Dresses', 'Tshirts', 'Kurtas', 'Perfume and Body Mist', 'Tshirts', 'Flats', 'Tshirts', 'Casual Shoes', 'Ties', 'Sports Shoes', 'Tshirts', 'Bra', 'Tshirts', 'Handbags')
