In [1]:
# Mount Google Drive at /content/drive so the dataset folders stored on Drive
# can be read through ordinary file paths in the cells below.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [16]:
import glob
import os
from torch.utils.data import Dataset
import cv2
import torch


def load_img(data_dir, labels):
  """Collect ``[file_path, class_index]`` pairs for the files of each label folder.

  Args:
    data_dir: Directory containing one sub-directory per label. A trailing
      path separator is optional.
    labels: Sequence of sub-directory names. The position of a name in this
      sequence becomes its integer class index.

  Returns:
    A list of ``[path, idx]`` lists, grouped by label in the order given and
    sorted by file name within each label.

  Raises:
    OSError: Propagated from ``os.listdir`` when a label directory is missing
      or is not a directory.
  """
  img_list = []
  for idx, label in enumerate(labels):
    # os.path.join works whether or not data_dir ends with a separator;
    # plain string concatenation silently produced a wrong path without one.
    label_dir = os.path.join(data_dir, label)
    # os.listdir returns entries in arbitrary order, so sort to keep the
    # sample order reproducible between runs.
    for name in sorted(os.listdir(label_dir)):
      file_path = os.path.join(label_dir, name)
      # Nested directories (e.g. notebook checkpoint folders) are not images.
      if os.path.isfile(file_path):
        img_list.append([file_path, idx])
  return img_list


class RuruDataset(Dataset):
  """Image dataset that reads files from per-label folders with OpenCV.

  Each item is a dict with two keys:
    "image": the image converted from BGR to RGB, passed through
      ``transform`` when one is set.
    "label": the class index of the folder the file came from. It is a
      one-element ``torch.long`` tensor when ``transform`` is set and the
      plain value stored in ``files_path`` otherwise.

  Args:
    data_dir: Directory containing one sub-directory per label.
    labels: Sequence of sub-directory names; a name's position is its class index.
    transform: Optional callable applied to the RGB image.
  """

  def __init__(self, data_dir, labels, transform=None):
    # List of [file_path, class_index] pairs built once up front.
    self.files_path = load_img(data_dir, labels)
    self.transform = transform

  def __len__(self):
    """Return the number of collected files."""
    return len(self.files_path)

  def __getitem__(self, idx):
    """Load, convert and (optionally) transform the sample at position ``idx``.

    Raises:
      OSError: If OpenCV cannot read the file at the stored path.
    """
    img_path = self.files_path[idx][0]
    label = self.files_path[idx][1]

    image = cv2.imread(img_path)
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with the offending path rather than with an opaque cvtColor error.
    if image is None:
      raise OSError(f"cv2.imread could not read image: {img_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    if self.transform:
      image = self.transform(image)
      # NOTE(review): the label is only converted to a tensor when a transform
      # is set, so its type depends on `transform` — confirm this is intended.
      label = torch.tensor([label], dtype=torch.long)

    return {"image": image, "label": label}

In [21]:
from torchvision import transforms

# Preprocessing pipeline applied to every image, in order:
# ToTensor -> Resize to 224x224 -> Normalize each of the three channels
# with mean 0.5 and std 0.5.
transformer = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(224, 224)),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

In [20]:
from torch.utils.data import DataLoader

train_data_dir = "/content/drive/MyDrive/한끼만/ruru/train/"  # change to your data path
test_data_dir = "/content/drive/MyDrive/한끼만/ruru/test/"  # change to your data path

# Class names are the sub-directories of the training folder, ordered by name
# length. os.listdir returns entries in arbitrary order and sorted() is stable,
# so names of equal length previously kept that arbitrary order; the name
# itself is used as a tie-breaker to keep the label -> index mapping identical
# across runs. Non-directory entries are skipped because they cannot be listed
# for images.
labels = sorted(
    (
        name
        for name in os.listdir(train_data_dir)
        if os.path.isdir(os.path.join(train_data_dir, name))
    ),
    key=lambda name: (len(name), name),
)

# Training split: shuffled, and the last incomplete batch is dropped.
train_dataset = RuruDataset(train_data_dir, labels, transformer)
train = DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last=True)

# Test split: built with the same `labels` list so class indices match training.
# NOTE(review): shuffle=True is kept as in the original; shuffling is usually
# unnecessary for evaluation — confirm whether later cells rely on it.
test_dataset = RuruDataset(test_data_dir, labels, transformer)
test = DataLoader(test_dataset, batch_size=4, shuffle=True, drop_last=False)