In [26]:
# 현재 OS 및 라이브러리 버전 체크 체크
current_os = platform.system()
print(f"Current OS: {current_os}")
print(f"CUDA: {torch.cuda.is_available()}")
print(f"Python Version: {platform.python_version()}")
print(f"torch Version: {torch.__version__}")
print(f"torchvision Version: {torchvision.__version__}")

Current OS: Linux
CUDA: True
Python Version: 3.8.5
torch Version: 1.9.0+cu102
torchvision Version: 0.8.2


In [3]:
import os
import sys
import gzip
import random
import platform
import warnings
import collections
from tqdm import tqdm, tqdm_notebook

In [4]:
import re
import requests
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

In [15]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, utils
from torchvision.io import read_image

In [6]:
from torch.utils.data import Dataset, DataLoader, random_split, SubsetRandomSampler, WeightedRandomSampler

In [7]:
# 테스트 데이터셋 폴더 경로를 지정해주세요.
test_dir = '/opt/ml/input/data/train'

In [8]:
# meta 데이터와 이미지 경로를 불러옵니다.
TRAIN_MY_PATH = {
    'trainCsv' : os.path.join(test_dir, 'train.csv'),
    'image' : os.path.join(test_dir, 'images')
}

In [53]:
class MyTrainDataset(Dataset) :
    def __init__(self, path, transform, train=True):
        self.img_data = pd.read_csv(path['trainCsv'])
        self.img_dir = path['image']
        self.classes = ['id', 'gender', 'race', 'age']
        
        self.train = train
        self.transform = transform
        self.path = path
        self._repr_indent = 4
        
    def __len__(self) :
        return len(self.img_data)
    
    def __getitem__(self, idx) :
        person_path = os.path.join(self.img_dir, self.img_data.iloc[idx,4])
        incorrect_mask_img_path = os.path.join(person_path, 'incorrect_mask.jpg')
        print(incorrect_mask_img_path)
        image = Image.open(incorrect_mask_img_path)
        id = self.img_data.iloc[idx, 0]
        if self.transform :
            image = self.transform(image)
        return image, id
    
    def __repr__(self):
        '''
        https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py
        '''
        head = "(Inform) My Custom Dataset"
        data_path = self._repr_indent*" " + "Data path: {}".format(self.path['image'])
        label_path = self._repr_indent*" " + "Label path: {}".format(self.path['trainCsv'])
        num_data = self._repr_indent*" " + "Number of datapoints: {}".format(self.__len__())
        num_classes = self._repr_indent*" " + "Number of classes: {}".format(len(self.classes))

        return '\n'.join([head,
                          data_path, label_path, 
                          num_data, num_classes])

In [54]:
dataset_train_My = MyTrainDataset(path = TRAIN_MY_PATH,
                                  transform = transforms.ToTensor(),
                                  train = True)

In [19]:
dataset_train_My

(Inform) My Custom Dataset
    Data path: /opt/ml/input/data/train/images
    Label path: /opt/ml/input/data/train/train.csv
    Number of datapoints: 2700
    Number of classes: 4

In [20]:
len(dataset_train_My)

2700

In [55]:
image, label = next(iter(dataset_train_My))
image, label

/opt/ml/input/data/train/images/000001_female_Asian_45/incorrect_mask.jpg


(tensor([[[0.7490, 0.7490, 0.7490,  ..., 0.7882, 0.7882, 0.7882],
          [0.7490, 0.7490, 0.7490,  ..., 0.7882, 0.7882, 0.7882],
          [0.7490, 0.7490, 0.7490,  ..., 0.7882, 0.7882, 0.7882],
          ...,
          [0.5843, 0.5882, 0.5882,  ..., 0.5922, 0.5922, 0.5922],
          [0.5725, 0.5725, 0.5725,  ..., 0.5961, 0.5961, 0.5961],
          [0.5608, 0.5608, 0.5608,  ..., 0.6078, 0.6078, 0.6078]],
 
         [[0.7451, 0.7451, 0.7451,  ..., 0.7843, 0.7843, 0.7843],
          [0.7451, 0.7451, 0.7451,  ..., 0.7843, 0.7843, 0.7843],
          [0.7451, 0.7451, 0.7451,  ..., 0.7843, 0.7843, 0.7843],
          ...,
          [0.3804, 0.3843, 0.3843,  ..., 0.3686, 0.3686, 0.3686],
          [0.3686, 0.3686, 0.3686,  ..., 0.3725, 0.3725, 0.3725],
          [0.3569, 0.3569, 0.3569,  ..., 0.3686, 0.3686, 0.3686]],
 
         [[0.7255, 0.7255, 0.7255,  ..., 0.7647, 0.7647, 0.7647],
          [0.7255, 0.7255, 0.7255,  ..., 0.7647, 0.7647, 0.7647],
          [0.7255, 0.7255, 0.7255,  ...,