# 준비

In [12]:
import torch
import random
import numpy as np
import os

seed = 50

# Seed every source of randomness the notebook touches so a full
# Restart & Run All reproduces the same splits, shuffles and weights.
os.environ["PYTHONHASHSEED"] = str(seed)
for seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_fn(seed)

# Deterministic cuDNN settings; cuDNN is also switched off entirely here,
# trading speed for reproducibility.
cudnn = torch.backends.cudnn
cudnn.deterministic, cudnn.benchmark, cudnn.enabled = True, False, False

In [18]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [19]:
# Show which device (cuda or cpu) was selected above.
device

device(type='cpu')

In [20]:
import pandas as pd

# Training labels and the submission template; both are read from data/.
labels, submission = (
    pd.read_csv("data/train.csv"),
    pd.read_csv("data/sample_submission.csv"),
)

In [21]:
from sklearn.model_selection import train_test_split

# Hold out 10% of the labelled rows for validation, stratified on has_cactus
# so both splits keep the same class balance. Reuse the notebook-wide `seed`
# (instead of repeating the literal 50) so changing the seed changes the split.
train, valid = train_test_split(
    labels,
    test_size=0.1,
    stratify=labels["has_cactus"],
    random_state=seed,
)

In [22]:
# Number of rows in the (train, valid) splits.
tuple(map(len, (train, valid)))

(15750, 1750)

## 데이터셋 형식화

In [24]:
import cv2
from torch.utils.data import Dataset

In [27]:
class ImageDataset(Dataset):
    """Map-style dataset that loads images listed in a DataFrame.

    Rows of ``df`` are read positionally: column 0 is the image file name
    (joined onto ``img_dir``) and column 1 is the label.

    Args:
        df: DataFrame whose first two columns are (image file name, label).
        img_dir: Directory containing the image files. Joined with
            ``os.path.join``, so a trailing separator is optional.
        transform: Optional callable applied to the RGB image returned by
            OpenCV (e.g. ``transforms.ToTensor()``).
    """

    def __init__(self, df, img_dir="./", transform=None):
        super().__init__()
        self.df = df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in ``df``."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th row (positional index).

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        img_id = self.df.iloc[idx, 0]
        img_path = os.path.join(self.img_dir, img_id)
        image = cv2.imread(img_path)
        # cv2.imread returns None instead of raising on a missing/unreadable
        # file; fail here with the offending path rather than later with an
        # opaque error from cvtColor or the transform.
        if image is None:
            raise FileNotFoundError(f"Missing or unreadable image file: {img_path}")
        # OpenCV loads channels in BGR order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        label = self.df.iloc[idx, 1]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

## 데이터셋 생성

In [29]:
from torchvision import transforms # 이미지 변환 모듈

# Convert each HxWxC uint8 image into a CxHxW float tensor.
transform = transforms.ToTensor()

# Both splits read images from the same directory; only the rows differ.
# NOTE(review): images are read from "train/" relative to the working
# directory, while the CSVs above are read from "data/" — confirm the image
# folder really lives at ./train/ and not ./data/train/.
dataset_train, dataset_valid = (
    ImageDataset(df=split_df, img_dir="train/", transform=transform)
    for split_df in (train, valid)
)

## 데이터 로더 생성

In [30]:
from torch.utils.data import DataLoader

# Shuffle only the training data. The validation loader iterates in a fixed
# order so evaluation is repeatable and predictions line up with `valid` rows.
loader_train = DataLoader(dataset=dataset_train, batch_size=32, shuffle=True)
loader_valid = DataLoader(dataset=dataset_valid, batch_size=32, shuffle=False)