In [11]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.transforms import Resize, ToTensor, Normalize
from PIL import Image
import numpy as np
import torch
import matplotlib.pyplot as plt



import import_ipynb
from getLabeling import label, all_train_data


from albumentations import *
from albumentations.pytorch import ToTensorV2

In [12]:
class TrainDataset(Dataset):
    """Map-style dataset pairing each image file with a one-hot class vector.

    Args:
        img_paths: Sequence of image file paths; item ``i`` is opened with PIL.
        transform: Callable invoked as ``transform(image=<ndarray>)`` whose
            result is indexed with ``['image']`` (the albumentations calling
            convention), or a falsy value to return the raw pixel array.
        label: Sequence aligned with ``img_paths``; ``int(label[i])`` is the
            class index of item ``i``.
        num_classes: Length of the one-hot target vector (default 18).
    """

    def __init__(self, img_paths, transform, label, num_classes=18):
        self.img_paths = img_paths
        self.transform = transform
        self.label = label
        self.num_classes = num_classes

    def set_transform(self, transform):
        """Replace the transform applied to every image loaded afterwards."""
        self.transform = transform

    def __getitem__(self, index):
        """Load item ``index`` and return ``(image, target)``.

        Returns:
            tuple: ``image`` is the ``'image'`` entry produced by the
            transform, or the untransformed ``numpy.ndarray`` when no
            transform is set; ``target`` is an int64 tensor of length
            ``num_classes`` holding 1 at the item's class index and 0
            elsewhere.

        Raises:
            IndexError: If the class index does not fit in ``num_classes``.
        """
        # PIL opens files lazily; copying the pixels inside the ``with`` block
        # guarantees the file handle is released before we return.
        with Image.open(self.img_paths[index]) as image:
            image_array = np.array(image)

        # Fall through with the raw array when no transform is configured so
        # the return value is always bound.
        if self.transform:
            image_array = self.transform(image=image_array)['image']

        target = torch.zeros(self.num_classes, dtype=torch.long)
        target[int(self.label[index])] = 1
        return image_array, target

    def __len__(self):
        """Return the number of image paths in the dataset."""
        return len(self.img_paths)

In [13]:
# train_transform = transforms.Compose([
#     #transforms.RandomCrop(32, padding=4),
#     #transforms.RandomHorizontalFlip(),
#     Resize((224, 224), Image.BILINEAR),
#     ToTensor(),
#     Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
# ])

# train_dataset = TrainDataset(all_train_data, train_transform, label)
# dataset = DataLoader(train_dataset, batch_size = 16, shuffle = True, num_workers = 12)  # setting shuffle=True lowered the loss


In [14]:
# n_val = int(len(dataset) * 0.3)
# n_train = len(dataset) - n_val
# train_dataset, val_dataset = torch.utils.data.random_split(dataset, [n_train, n_val])

In [15]:
# loader_train = DataLoader(
#     train_dataset,
#     batch_size = 16,
#     num_workers = 12,
#     shuffle = True
# )

# loader_val = DataLoader(
#     val_dataset,
#     batch_size = 16,
#     num_workers = 12,
#     shuffle = False
# )

In [16]:
def get_transform(need=('train', 'val'), img_size = (224, 224), mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)):
    """Build the albumentations pipelines for the requested splits.

    Args:
        need: Container of split names; a pipeline is built for each of
            ``'train'`` and ``'val'`` that it contains.
        img_size: ``(height, width)`` every image is resized to.
        mean: Per-channel mean passed to ``Normalize``.
        std: Per-channel standard deviation passed to ``Normalize``.

    Returns:
        dict: Maps each requested split name to its ``Compose`` pipeline.
        The ``'train'`` pipeline adds random flip, shift/scale/rotate and
        Gaussian noise; the ``'val'`` pipeline only resizes, normalizes and
        converts to a tensor.
    """
    # Resolve the albumentations classes explicitly. At module level the bare
    # names ``Resize`` and ``Normalize`` are imported from torchvision as well
    # as through ``from albumentations import *``, so they only refer to the
    # albumentations classes while the wildcard import happens to run last.
    import albumentations as A
    from albumentations.pytorch import ToTensorV2

    height, width = img_size[0], img_size[1]

    transformations = {}
    if 'train' in need:
        transformations['train'] = A.Compose([
            A.Resize(height, width, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(p=0.5),
            # Colour augmentations are currently disabled:
            #A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
            #A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
            A.GaussNoise(p=0.5),
            A.Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0),
        ], p=1.0)
    if 'val' in need:
        transformations['val'] = A.Compose([
            A.Resize(height, width),
            A.Normalize(mean=mean, std=std, max_pixel_value=255.0, p=1.0),
            ToTensorV2(p=1.0),
        ], p=1.0)
    return transformations

In [17]:
# get_transform() returns a dict of pipelines keyed by split name, so a single
# pipeline must be picked here: handing the whole dict to TrainDataset would
# make indexing the dataset fail with "'dict' object is not callable".
# The deterministic validation pipeline is used as the default.
transform = get_transform()
dataset = TrainDataset(all_train_data, transform['val'], label)

In [18]:
# Hold out 30% of the samples for validation. The explicitly seeded generator
# pins the shuffle so the same split is produced on every Restart & Run All.
n_val = int(len(dataset) * 0.3)
n_train = len(dataset) - n_val
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(42),
)

In [19]:
# random_split returns two Subset views over the SAME TrainDataset object, so
# calling set_transform through each view would leave only the last transform
# (the validation one) active for both splits and silently disable training
# augmentation. Give each split its own TrainDataset instead, reusing the
# indices chosen by random_split.
train_dataset = torch.utils.data.Subset(
    TrainDataset(all_train_data, transform['train'], label),
    train_dataset.indices,
)
val_dataset = torch.utils.data.Subset(
    TrainDataset(all_train_data, transform['val'], label),
    val_dataset.indices,
)

In [20]:
# Build the training loader first, then the validation loader. Both use the
# same batch size and worker count; only the training split is reshuffled.
loader_train, loader_val = (
    DataLoader(split, batch_size=4, num_workers=4, shuffle=reshuffle)
    for split, reshuffle in ((train_dataset, True), (val_dataset, False))
)